diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 046e8f0c57..ed3ab0375c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,7 +2,7 @@ /.github/workflows/ @lstein @blessedcoolant # documentation -/docs/ @lstein @tildebyte @blessedcoolant +/docs/ @lstein @blessedcoolant @hipsterusername /mkdocs.yml @lstein @blessedcoolant # nodes @@ -18,17 +18,17 @@ /invokeai/version @lstein @blessedcoolant # web ui -/invokeai/frontend @blessedcoolant @psychedelicious @lstein -/invokeai/backend @blessedcoolant @psychedelicious @lstein +/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp +/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp # generation, model management, postprocessing -/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 +/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779 # front ends /invokeai/frontend/CLI @lstein /invokeai/frontend/install @lstein @ebr -/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername -/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername -/invokeai/frontend/web @psychedelicious @blessedcoolant +/invokeai/frontend/merge @lstein @blessedcoolant +/invokeai/frontend/training @lstein @blessedcoolant +/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index 17673de937..071232e06e 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -125,6 +125,7 @@ jobs: --no-nsfw_checker --precision=float32 --always_use_cpu + --use_memory_db --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }} --from_file ${{ env.TEST_PROMPTS }} diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 0615ff187e..920181ff8b 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,5 +1,6 @@ import io -from fastapi import HTTPException, Path, Query, Request, Response, UploadFile +from typing import Optional +from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.routing import APIRouter from fastapi.responses import FileResponse from PIL import Image @@ -7,7 +8,11 @@ from invokeai.app.models.image import ( ImageCategory, ImageType, ) -from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO +from invokeai.app.services.models.image_record import ( + ImageDTO, + ImageRecordChanges, + ImageUrlsDTO, +) from invokeai.app.services.item_storage import PaginatedResults from ..dependencies import ApiDependencies @@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"]) ) async def upload_image( file: UploadFile, - image_type: ImageType, request: Request, response: Response, - image_category: ImageCategory = ImageCategory.GENERAL, + image_category: ImageCategory = Query( + default=ImageCategory.GENERAL, description="The category of the image" + ), + is_intermediate: bool = Query( + default=False, description="Whether this is an intermediate image" + ), + session_id: Optional[str] = Query( + default=None, description="The session ID associated with this upload, if any" + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type.startswith("image"): @@ -46,9 +58,11 @@ async def upload_image( try: image_dto = ApiDependencies.invoker.services.images.create( - pil_image, - image_type, - image_category, + image=pil_image, + 
image_type=ImageType.UPLOAD, + image_category=image_category, + session_id=session_id, + is_intermediate=is_intermediate, ) response.status_code = 201 @@ -61,7 +75,7 @@ async def upload_image( @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Query(description="The type of image to delete"), + image_type: ImageType = Path(description="The type of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" @@ -73,6 +87,28 @@ async def delete_image( pass +@images_router.patch( + "/{image_type}/{image_name}", + operation_id="update_image", + response_model=ImageDTO, +) +async def update_image( + image_type: ImageType = Path(description="The type of image to update"), + image_name: str = Path(description="The name of the image to update"), + image_changes: ImageRecordChanges = Body( + description="The changes to apply to the image" + ), +) -> ImageDTO: + """Updates an image""" + + try: + return ApiDependencies.invoker.services.images.update( + image_type, image_name, image_changes + ) + except Exception as e: + raise HTTPException(status_code=400, detail="Failed to update image") + + @images_router.get( "/{image_type}/{image_name}/metadata", operation_id="get_image_metadata", @@ -85,9 +121,7 @@ async def get_image_metadata( """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto( - image_type, image_name - ) + return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) except Exception as e: raise HTTPException(status_code=404) @@ -113,9 +147,7 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name - ) + path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 073a8f569b..de543d2d85 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -13,10 +13,13 @@ from typing import ( from pydantic import BaseModel, ValidationError from pydantic.fields import Field +from invokeai.app.services.image_record_storage import SqliteImageRecordStorage +from invokeai.app.services.images import ImageService +from invokeai.app.services.metadata import CoreMetadataService +from invokeai.app.services.urls import LocalUrlService import invokeai.backend.util.logging as logger -from invokeai.app.services.metadata import PngMetadataService from .services.default_graphs import create_system_graphs from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage @@ -188,6 +191,9 @@ def invoke_all(context: CliContext): raise SessionError() +logger = logger.InvokeAILogger.getLogger() + + def invoke_cli(): # this gets the basic configuration config = get_invokeai_config() @@ -206,24 +212,43 @@ def invoke_cli(): events = EventServiceBase() output_folder = config.output_path - metadata = PngMetadataService() # TODO: build a file/path manager? 
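+ # The block below picks the SQLite location: when config.use_memory_db is set
+ # (the new --use_memory_db flag, exercised by the test-invoke-pip workflow above),
+ # image records go to a throwaway in-memory database; otherwise invokeai.db is
+ # created under the configured output folder.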
- db_location = os.path.join(output_folder, "invokeai.db") + if config.use_memory_db: + db_location = ":memory:" + else: + db_location = os.path.join(output_folder, "invokeai.db") + + logger.info(f'InvokeAI database location is "{db_location}"') + + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=db_location, table_name="graph_executions" + ) + + urls = LocalUrlService() + metadata = CoreMetadataService() + image_record_storage = SqliteImageRecordStorage(db_location) + image_file_storage = DiskImageFileStorage(f"{output_folder}/images") + + images = ImageService( + image_record_storage=image_record_storage, + image_file_storage=image_file_storage, + metadata=metadata, + url=urls, + logger=logger, + graph_execution_manager=graph_execution_manager, + ) services = InvocationServices( model_manager=model_manager, events=events, latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), - images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata), - metadata=metadata, + images=images, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" ), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=db_location, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), restoration=RestorationServices(config,logger=logger), logger=logger, diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index da61641105..4ce3e839b6 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel): #fmt: off id: str = Field(description="The id of this node. 
Must be unique among all nodes.") + is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") #fmt: on @@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False): "image", "latents", "model", + "control", ], ] tags: List[str] diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 475b6028a9..891f217317 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput): # Outputs collection: list[int] = Field(default=[], description="The int collection") +class FloatCollectionOutput(BaseInvocationOutput): + """A collection of floats""" + + type: Literal["float_collection"] = "float_collection" + + # Outputs + collection: list[float] = Field(default=[], description="The float collection") + class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py new file mode 100644 index 0000000000..187784b29e --- /dev/null +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -0,0 +1,428 @@ +# InvokeAI nodes for ControlNet image preprocessors +# initial implementation by Gregg Helt, 2023 +# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux + +import numpy as np +from typing import Literal, Optional, Union, List +from PIL import Image, ImageFilter, ImageOps +from pydantic import BaseModel, Field + +from ..models.image import ImageField, ImageType, ImageCategory +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, + InvocationConfig, +) +from controlnet_aux import ( + CannyDetector, + HEDdetector, + LineartDetector, + LineartAnimeDetector, + MidasDetector, + MLSDdetector, + NormalBaeDetector, + OpenposeDetector, + PidiNetDetector, + ContentShuffleDetector, + ZoeDetector, + MediapipeFaceDetector, +) + +from .image import ImageOutput, PILInvocationConfig + +CONTROLNET_DEFAULT_MODELS = [ + ########################################### + # lllyasviel sd v1.5, ControlNet v1.0 models + ############################################## + "lllyasviel/sd-controlnet-canny", + "lllyasviel/sd-controlnet-depth", + "lllyasviel/sd-controlnet-hed", + "lllyasviel/sd-controlnet-seg", + "lllyasviel/sd-controlnet-openpose", + "lllyasviel/sd-controlnet-scribble", + "lllyasviel/sd-controlnet-normal", + "lllyasviel/sd-controlnet-mlsd", + + ############################################# + # lllyasviel sd v1.5, ControlNet v1.1 models + ############################################# + "lllyasviel/control_v11p_sd15_canny", + "lllyasviel/control_v11p_sd15_openpose", + "lllyasviel/control_v11p_sd15_seg", + # "lllyasviel/control_v11p_sd15_depth", # broken + "lllyasviel/control_v11f1p_sd15_depth", + "lllyasviel/control_v11p_sd15_normalbae", + "lllyasviel/control_v11p_sd15_scribble", + "lllyasviel/control_v11p_sd15_mlsd", + "lllyasviel/control_v11p_sd15_softedge", + "lllyasviel/control_v11p_sd15s2_lineart_anime", + "lllyasviel/control_v11p_sd15_lineart", + "lllyasviel/control_v11p_sd15_inpaint", + # "lllyasviel/control_v11u_sd15_tile", + # problem (temporary?) 
with huggingface "lllyasviel/control_v11u_sd15_tile", + # so for now replace "lllyasviel/control_v11f1e_sd15_tile", + "lllyasviel/control_v11e_sd15_shuffle", + "lllyasviel/control_v11e_sd15_ip2p", + "lllyasviel/control_v11f1e_sd15_tile", + + ################################################# + # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?) + ################################################## + "thibaud/controlnet-sd21-openpose-diffusers", + "thibaud/controlnet-sd21-canny-diffusers", + "thibaud/controlnet-sd21-depth-diffusers", + "thibaud/controlnet-sd21-scribble-diffusers", + "thibaud/controlnet-sd21-hed-diffusers", + "thibaud/controlnet-sd21-zoedepth-diffusers", + "thibaud/controlnet-sd21-color-diffusers", + "thibaud/controlnet-sd21-openposev2-diffusers", + "thibaud/controlnet-sd21-lineart-diffusers", + "thibaud/controlnet-sd21-normalbae-diffusers", + "thibaud/controlnet-sd21-ade20k-diffusers", + + ############################################## + # ControlNetMediaPipeface, ControlNet v1.1 + ############################################## + # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5 + # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg + # hacked t2l to split to model & subfolder if format is "model,subfolder" + "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5 + "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1? +] + +CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] + +class ControlField(BaseModel): + image: ImageField = Field(default=None, description="processed image") + control_model: Optional[str] = Field(default=None, description="control model used") + control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") + begin_step_percent: float = Field(default=0, ge=0, le=1, + description="% of total steps at which controlnet is first applied") + end_step_percent: float = Field(default=1, ge=0, le=1, + description="% of total steps at which controlnet is last applied") + + class Config: + schema_extra = { + "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"] + } + + +class ControlOutput(BaseInvocationOutput): + """node output for ControlNet info""" + # fmt: off + type: Literal["control_output"] = "control_output" + control: ControlField = Field(default=None, description="The control info dict") + # fmt: on + + +class ControlNetInvocation(BaseInvocation): + """Collects ControlNet info to pass to other nodes""" + # fmt: off + type: Literal["controlnet"] = "controlnet" + # Inputs + image: ImageField = Field(default=None, description="image to process") + control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", + description="control model used") + control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet") + # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode + begin_step_percent: float = Field(default=0, ge=0, le=1, + description="% of total steps at which controlnet is first applied") + end_step_percent: float = Field(default=1, ge=0, le=1, + description="% of total steps at which controlnet is last applied") + # fmt: on + + + def invoke(self, context: InvocationContext) -> ControlOutput: + + return ControlOutput( + control=ControlField( + image=self.image, + control_model=self.control_model, + control_weight=self.control_weight, + begin_step_percent=self.begin_step_percent, + end_step_percent=self.end_step_percent, + ), + ) + +# 
TODO: move image processors to separate file (image_analysis.py) +class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): + """Base class for invocations that preprocess images for ControlNet""" + + # fmt: off + type: Literal["image_processor"] = "image_processor" + # Inputs + image: ImageField = Field(default=None, description="image to process") + # fmt: on + + + def run_processor(self, image): + # superclass just passes through image without processing + return image + + def invoke(self, context: InvocationContext) -> ImageOutput: + + raw_image = context.services.images.get_pil_image( + self.image.image_type, self.image.image_name + ) + # image type should be PIL.PngImagePlugin.PngImageFile ? + processed_image = self.run_processor(raw_image) + + # FIXME: what happened to image metadata? + # metadata = context.services.metadata.build_metadata( + # session_id=context.graph_execution_state_id, node=self + # ) + + # currently can't see processed image in node UI without a showImage node, + # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery + image_dto = context.services.images.create( + image=processed_image, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + session_id=context.graph_execution_state_id, + node_id=self.id, + is_intermediate=self.is_intermediate + ) + + """Builds an ImageOutput and its ImageField""" + processed_image_field = ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ) + return ImageOutput( + image=processed_image_field, + # width=processed_image.width, + width = image_dto.width, + # height=processed_image.height, + height = image_dto.height, + # mode=processed_image.mode, + ) + + +class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Canny edge detection for ControlNet""" + # fmt: off + type: Literal["canny_image_processor"] = "canny_image_processor" + # Input + low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient") + high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient") + # fmt: on + + def run_processor(self, image): + canny_processor = CannyDetector() + processed_image = canny_processor(image, self.low_threshold, self.high_threshold) + return processed_image + + +class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies HED edge detection to image""" + # fmt: off + type: Literal["hed_image_processor"] = "hed_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # safe not supported in controlnet_aux v0.0.3 + # safe: bool = Field(default=False, description="whether to use safe mode") + scribble: bool = Field(default=False, description="whether to use scribble mode") + # fmt: on + + def run_processor(self, image): + hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") + processed_image = hed_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + # safe not supported in controlnet_aux v0.0.3 + # safe=self.safe, + scribble=self.scribble, + ) + return processed_image + + +class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies line art processing to image""" + # fmt: off + type: Literal["lineart_image_processor"] = 
"lineart_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + coarse: bool = Field(default=False, description="whether to use coarse mode") + # fmt: on + + def run_processor(self, image): + lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") + processed_image = lineart_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + coarse=self.coarse) + return processed_image + + +class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies line art anime processing to image""" + # fmt: off + type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # fmt: on + + def run_processor(self, image): + processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + ) + return processed_image + + +class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Openpose processing to image""" + # fmt: off + type: Literal["openpose_image_processor"] = "openpose_image_processor" + # Inputs + hand_and_face: bool = Field(default=False, description="whether to use hands and face mode") + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # fmt: on + + def run_processor(self, image): + openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = openpose_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + hand_and_face=self.hand_and_face, + ) + return processed_image + + +class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Midas depth processing to image""" + # fmt: off + type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" + # Inputs + a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI") + bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th") + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode") + # fmt: on + + def run_processor(self, image): + midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") + processed_image = midas_processor(image, + a=np.pi * self.a_mult, + bg_th=self.bg_th, + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal=self.depth_and_normal, + ) + return processed_image + + +class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies NormalBae processing to image""" + # fmt: off + type: Literal["normalbae_image_processor"] = "normalbae_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel 
resolution for output image") + # fmt: on + + def run_processor(self, image): + normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = normalbae_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution) + return processed_image + + +class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies MLSD processing to image""" + # fmt: off + type: Literal["mlsd_image_processor"] = "mlsd_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v") + thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d") + # fmt: on + + def run_processor(self, image): + mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") + processed_image = mlsd_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + thr_v=self.thr_v, + thr_d=self.thr_d) + return processed_image + + +class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies PIDI processing to image""" + # fmt: off + type: Literal["pidi_image_processor"] = "pidi_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + safe: bool = Field(default=False, description="whether to use safe mode") + scribble: bool = Field(default=False, description="whether to use scribble mode") + # fmt: on + + def run_processor(self, image): + pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") + processed_image = pidi_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + safe=self.safe, + scribble=self.scribble) + return processed_image + + +class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies content shuffle processing to image""" + # fmt: off + type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + h: Union[int, None] = Field(default=512, ge=0, description="content shuffle h parameter") + w: Union[int, None] = Field(default=512, ge=0, description="content shuffle w parameter") + f: Union[int, None] = Field(default=256, ge=0, description="content shuffle f parameter") + # fmt: on + + def run_processor(self, image): + content_shuffle_processor = ContentShuffleDetector() + processed_image = content_shuffle_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + h=self.h, + w=self.w, + f=self.f + ) + return processed_image + + +# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 +class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Zoe depth processing to image""" + # fmt: off + type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" + # fmt: on + + def run_processor(self, image): + zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") + 
processed_image = zoe_depth_processor(image) + return processed_image + + +class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies mediapipe face processing to image""" + # fmt: off + type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" + # Inputs + max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect") + min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection") + # fmt: on + + def run_processor(self, image): + mediapipe_face_processor = MediapipeFaceDetector() + processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) + return processed_image diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 26e06a2af8..5e9fe088b5 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -57,10 +57,11 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): image_dto = context.services.images.create( image=image_inpainted, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index aa16243093..0385c6a9f0 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -4,7 +4,9 @@ from functools import partial from typing import Literal, Optional, Union, get_args import numpy as np +from diffusers import ControlNetModel from torch import Tensor +import torch from pydantic import BaseModel, Field @@ -56,8 +58,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" ) + scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) model: str = Field(default="", description="The model to use (currently ignored)") + progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) + control_model: Optional[str] = Field(default=None, description="The control model to use") + control_image: Optional[ImageField] = Field(default=None, description="The processed control image") # fmt: on # TODO: pass this an emitter method or something? or a session for dispatching? @@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # Handle invalid model parameter model = choose_model(context.services.model_manager, self.model) + # loading controlnet image (currently requires pre-processed image) + control_image = ( + None if self.control_image is None + else context.services.images.get( + self.control_image.image_type, self.control_image.image_name + ) + ) + # loading controlnet model + if (self.control_model is None or self.control_model==''): + control_model = None + else: + # FIXME: change this to dropdown menu? 
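+ # ControlNetModel.from_pretrained downloads (and caches) the ControlNet weights
+ # from the Hugging Face hub; dtype and device are hardcoded to fp16/CUDA here,
+ # hence the FIXME below.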
+ # FIXME: generalize so don't have to hardcode torch_dtype and device + control_model = ControlNetModel.from_pretrained(self.control_model, + torch_dtype=torch.float16).to("cuda") + # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( context.graph_execution_state_id ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - outputs = Txt2Img(model).generate( + txt2img = Txt2Img(model, control_model=control_model) + outputs = txt2img.generate( prompt=self.prompt, step_callback=partial(self.dispatch_progress, context, source_node_id), + control_image=control_image, **self.dict( - exclude={"prompt"} + exclude={"prompt", "control_image" } ), # Shorthand for passing all of the parameters above manually ) # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object @@ -101,6 +124,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -181,6 +205,7 @@ class ImageToImageInvocation(TextToImageInvocation): image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -296,6 +321,7 @@ class InpaintInvocation(ImageToImageInvocation): image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 21dfb4c1cd..69d51e6158 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -143,6 +143,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -204,6 +205,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -242,6 +244,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return MaskOutput( @@ -280,6 +283,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -318,6 +322,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -356,6 +361,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -397,6 +403,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -437,6 +444,7 
@@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -482,6 +490,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 17a43dbdac..ad60b62633 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -149,6 +149,7 @@ class InfillColorInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -193,6 +194,7 @@ class InfillTileInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( @@ -230,6 +232,7 @@ class InfillPatchMatchInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 34da76d39a..4975b7b578 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,8 +1,11 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) import random -from typing import Literal, Optional, Union import einops +from typing import Literal, Optional, Union, List + +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from pydantic import BaseModel, Field, validator import torch @@ -11,14 +14,18 @@ from invokeai.app.models.image import ImageCategory from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback +from .controlnet_image_processors import ControlField from ...backend.model_management.model_manager import ModelManager from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.image_util.seamless import configure_model_padding from ...backend.prompting.conditioning import get_uc_and_c_and_ec + from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP +from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData + from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import numpy as np from ..services.image_file_storage import ImageType @@ -28,7 +35,7 @@ from .compel import ConditioningField from ...backend.stable_diffusion import PipelineIntermediateState from diffusers.schedulers import SchedulerMixin as Scheduler import diffusers -from diffusers import DiffusionPipeline +from diffusers import DiffusionPipeline, ControlNetModel class LatentsField(BaseModel): @@ -84,13 +91,13 @@ SAMPLER_NAME_VALUES = Literal[ def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: scheduler_class, 
scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - + scheduler_config = model.scheduler.config if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -167,8 +174,9 @@ class TextToLatentsInvocation(BaseInvocation): noise: Optional[LatentsField] = Field(description="The noise to use") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" ) + scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) model: str = Field(default="", description="The model to use (currently ignored)") + control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @@ -179,7 +187,8 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents", "image"], "type_hints": { - "model": "model" + "model": "model", + "control": "control", } }, } @@ -238,6 +247,81 @@ class TextToLatentsInvocation(BaseInvocation): ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta) return conditioning_data + def prep_control_data(self, + context: InvocationContext, + model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + control_input: List[ControlField], + latents_shape: List[int], + do_classifier_free_guidance: bool = True, + ) -> List[ControlNetData]: + # assuming fixed dimensional scaling of 8:1 for image:latents + control_height_resize = latents_shape[2] * 8 + control_width_resize = latents_shape[3] * 8 + if control_input is None: + # print("control input is None") + control_list = None + elif isinstance(control_input, list) and len(control_input) == 0: + # print("control input is empty list") + control_list = None + elif isinstance(control_input, ControlField): + # print("control input is ControlField") + control_list = [control_input] + elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): + # print("control input is list[ControlField]") + control_list = control_input + else: + # print("input control is unrecognized:", type(self.control)) + control_list = None + if (control_list is None): + control_data = None + # from above handling, any control that is not None should now be of type list[ControlField] + else: + # FIXME: add checks to skip entry if model or image is None + # and if weight is None, populate with default 1.0? 
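+ # For each ControlField below: load its ControlNetModel (a "model,subfolder"
+ # string selects a subfolder within a HF repo, per the convention noted in
+ # CONTROLNET_DEFAULT_MODELS), fetch the control image, resize it to the
+ # 8x-scaled latents dimensions, and bundle model, image tensor, weight, and
+ # step range into a ControlNetData entry.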
+ control_data = [] + control_models = [] + for control_info in control_list: + # handle control models + if ("," in control_info.control_model): + control_model_split = control_info.control_model.split(",") + control_name = control_model_split[0] + control_subfolder = control_model_split[1] + print("Using HF model subfolders") + print(" control_name: ", control_name) + print(" control_subfolder: ", control_subfolder) + control_model = ControlNetModel.from_pretrained(control_name, + subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to(model.device) + else: + control_model = ControlNetModel.from_pretrained(control_info.control_model, + torch_dtype=model.unet.dtype).to(model.device) + control_models.append(control_model) + control_image_field = control_info.image + input_image = context.services.images.get_pil_image(control_image_field.image_type, + control_image_field.image_name) + # self.image.image_type, self.image.image_name + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? + # and do real check for classifier_free_guidance? + # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) + control_image = model.prepare_control_image( + image=input_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=control_width_resize, + height=control_height_resize, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=control_model.device, + dtype=control_model.dtype, + ) + control_item = ControlNetData(model=control_model, + image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent) + control_data.append(control_item) + # MultiControlNetModel has been refactored out, just need list[ControlNetData] + return control_data def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) @@ -252,14 +336,19 @@ class TextToLatentsInvocation(BaseInvocation): model = self.get_model(context.services.model_manager) conditioning_data = self.get_conditioning_data(context, model) - # TODO: Verify the noise is the right size + print("type of control input: ", type(self.control)) + control_data = self.prep_control_data(model=model, context=context, control_input=self.control, + latents_shape=noise.shape, + do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + # TODO: Verify the noise is the right size result_latents, result_attention_map_saver = model.latents_from_embeddings( latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)), noise=noise, num_inference_steps=self.steps, conditioning_data=conditioning_data, - callback=step_callback + control_data=control_data, # list[ControlNetData] + callback=step_callback, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 @@ -285,7 +374,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model" + "model": "model", + "control": "control", } }, } @@ -304,6 +394,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): model = self.get_model(context.services.model_manager) conditioning_data = self.get_conditioning_data(context, model) + print("type of control input: ", type(self.control)) + control_data = self.prep_control_data(model=model, context=context, control_input=self.control, + 
latents_shape=noise.shape, + do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( @@ -318,6 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): noise=noise, num_inference_steps=self.steps, conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] callback=step_callback ) @@ -362,14 +458,21 @@ class LatentsToImageInvocation(BaseInvocation): np_image = model.decode_latents(latents) image = model.numpy_to_pil(np_image)[0] + # what happened to metadata? + # metadata = context.services.metadata.build_metadata( + # session_id=context.graph_execution_state_id, node=self + torch.cuda.empty_cache() + # new (post Image service refactor) way of using services to save image + # and generate unique image_name image_dto = context.services.images.create( image=image, image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate ) return ImageOutput( @@ -413,6 +516,7 @@ class ResizeLatentsInvocation(BaseInvocation): torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, resized_latents) context.services.latents.save(name, resized_latents) return build_latents_output(latents_name=name, latents=resized_latents) @@ -443,6 +547,7 @@ class ScaleLatentsInvocation(BaseInvocation): torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, resized_latents) context.services.latents.save(name, resized_latents) return build_latents_output(latents_name=name, latents=resized_latents) @@ -467,6 +572,9 @@ class ImageToLatentsInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: + # image = context.services.images.get( + # self.image.image_type, self.image.image_name + # ) image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -487,6 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation): ) name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, latents) context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) - diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 2ce58c016b..113b630200 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput): # fmt: on +class FloatOutput(BaseInvocationOutput): + """A float output""" + + # fmt: off + type: Literal["float_output"] = "float_output" + param: float = Field(default=None, description="The output float") + # fmt: on + + class AddInvocation(BaseInvocation, MathInvocationConfig): """Adds two numbers""" diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py index fcc7f1737a..1c6297665b 100644 --- a/invokeai/app/invocations/params.py +++ b/invokeai/app/invocations/params.py @@ -3,7 +3,7 @@ from typing import Literal from pydantic import Field from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from .math import IntOutput +from .math import IntOutput, FloatOutput # Pass-through parameter nodes - used by subgraphs @@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> 
IntOutput: return IntOutput(a=self.a) + +class ParamFloatInvocation(BaseInvocation): + """A float parameter""" + #fmt: off + type: Literal["param_float"] = "param_float" + param: float = Field(default=0.0, description="The float value") + #fmt: on + + def invoke(self, context: InvocationContext) -> FloatOutput: + return FloatOutput(param=self.param) diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 024134cd46..db71e4201d 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -43,10 +43,11 @@ class RestoreFaceInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 75aeec784f..90c9e4bf4f 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -49,6 +49,7 @@ class UpscaleInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 544951ea34..46b50145aa 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum): RESULT = "results" UPLOAD = "uploads" - INTERMEDIATE = "intermediates" class InvalidImageTypeException(ValueError): diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index 2d87125744..49e0b6bed4 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -353,6 +353,7 @@ setting environment variables INVOKEAI_. sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') + root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths') conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') @@ -362,6 +363,7 @@ setting environment variables INVOKEAI_. 
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths') outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') + use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models') @@ -511,7 +513,7 @@ class PagingArgumentParser(argparse.ArgumentParser): text = self.format_help() pydoc.pager(text) -def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings: +def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig: ''' This returns a singleton InvokeAIAppConfig configuration object. ''' diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 44688ada0a..60e196faa1 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any: node_input_field = node_inputs.get(field) or None return node_input_field +from typing import Optional, Union, List, get_args + +def is_union_subtype(t1, t2): + t1_args = get_args(t1) + t2_args = get_args(t2) + + if not t1_args: + # t1 is a single type + return t1 in t2_args + else: + # t1 is a Union, check that all of its types are in t2_args + return all(arg in t2_args for arg in t1_args) + +def is_list_or_contains_list(t): + t_args = get_args(t) + + # If the type is a List + if get_origin(t) is list: + return True + + # If the type is a Union + elif t_args: + # Check if any of the types in the Union is a List + for arg in t_args: + if get_origin(arg) is list: + return True + + return False + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type: @@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if to_type in get_args(from_type): return True - if not issubclass(from_type, to_type): + # if not issubclass(from_type, to_type): + if not is_union_subtype(from_type, to_type): return False else: return False @@ -694,7 +724,11 @@ class Graph(BaseModel): input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore # Verify that all outputs are lists - if not all((get_origin(f) == list for f in output_fields)): + # if not all((get_origin(f) == list for f in output_fields)): + # return False + + # Verify that all outputs are lists + if not all(is_list_or_contains_list(f) for f in output_fields): return False # Verify that all outputs match the input type (are a base class or the same class) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 4e1f73978b..188a411a6b 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -12,6 +12,7 @@ from invokeai.app.models.image import ( ) from invokeai.app.services.models.image_record import ( ImageRecord, + ImageRecordChanges, deserialize_image_record, ) from invokeai.app.services.item_storage import PaginatedResults @@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC): """Gets an image record.""" pass + 
@abstractmethod + def update( + self, + image_name: str, + image_type: ImageType, + changes: ImageRecordChanges, + ) -> None: + """Updates an image record.""" + pass + @abstractmethod def get_many( self, @@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC): session_id: Optional[str], node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: """Saves an image record.""" pass @@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, + is_intermediate BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def update( + self, + image_name: str, + image_type: ImageType, + changes: ImageRecordChanges, + ) -> None: + try: + self._lock.acquire() + # Change the category of the image + if changes.image_category is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET image_category = ? + WHERE image_name = ?; + """, + (changes.image_category, image_name), + ) + + # Change the session associated with the image + if changes.session_id is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET session_id = ? + WHERE image_name = ?; + """, + (changes.session_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordSaveException from e + finally: + self._lock.release() + def get_many( self, image_type: ImageType, @@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): height: int, node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: try: metadata_json = ( @@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): height, node_id, session_id, - metadata + metadata, + is_intermediate ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, @@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id, session_id, metadata_json, + is_intermediate, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 914dd3b6d3..d0f7236fe2 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import ( from invokeai.app.services.models.image_record import ( ImageRecord, ImageDTO, + ImageRecordChanges, image_record_to_dto, ) from invokeai.app.services.image_file_storage import ( @@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import ( from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.urls import UrlServiceBase -from invokeai.app.util.misc import get_iso_timestamp if TYPE_CHECKING: from invokeai.app.services.graph import GraphExecutionState @@ -48,11 +48,21 @@ class ImageServiceABC(ABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - metadata: Optional[ImageMetadata] = None, + intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass + @abstractmethod + def update( + self, + image_type: ImageType, + image_name: str, + changes: ImageRecordChanges, + ) -> ImageDTO: + """Updates an 
image.""" + pass + @abstractmethod def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" @@ -157,6 +167,7 @@ class ImageService(ImageServiceABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, + is_intermediate: bool = False, ) -> ImageDTO: if image_type not in ImageType: raise InvalidImageTypeException @@ -184,6 +195,8 @@ class ImageService(ImageServiceABC): image_category=image_category, width=width, height=height, + # Meta fields + is_intermediate=is_intermediate, # Nullable fields node_id=node_id, session_id=session_id, @@ -217,6 +230,7 @@ class ImageService(ImageServiceABC): created_at=created_at, updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, + is_intermediate=is_intermediate, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -231,6 +245,23 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem saving image record and file") raise e + def update( + self, + image_type: ImageType, + image_name: str, + changes: ImageRecordChanges, + ) -> ImageDTO: + try: + self._services.records.update(image_name, image_type, changes) + return self.get_dto(image_type, image_name) + except ImageRecordSaveException: + self._services.logger.error("Failed to update image record") + raise + except Exception as e: + self._services.logger.error("Problem updating image record") + raise e + + def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: try: return self._services.files.get(image_type, image_name) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index bcbe95a41f..1f910253e5 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -1,18 +1,17 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +from __future__ import annotations from typing import TYPE_CHECKING -from logging import Logger - -from invokeai.app.services.images import ImageService -from invokeai.backend import ModelManager -from .events import EventServiceBase -from .latent_storage import LatentsStorageBase -from .restoration_services import RestorationServices -from .invocation_queue import InvocationQueueABC -from .item_storage import ItemStorageABC -from .config import InvokeAISettings - if TYPE_CHECKING: + from logging import Logger + from invokeai.app.services.images import ImageService + from invokeai.backend import ModelManager + from invokeai.app.services.events import EventServiceBase + from invokeai.app.services.latent_storage import LatentsStorageBase + from invokeai.app.services.restoration_services import RestorationServices + from invokeai.app.services.invocation_queue import InvocationQueueABC + from invokeai.app.services.item_storage import ItemStorageABC + from invokeai.app.services.config import InvokeAISettings from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.invoker import InvocationProcessorABC @@ -20,32 +19,33 @@ if TYPE_CHECKING: class InvocationServices: """Services that can be used by invocations""" - events: EventServiceBase - latents: LatentsStorageBase - queue: InvocationQueueABC - model_manager: ModelManager - restoration: RestorationServices - configuration: InvokeAISettings - images: ImageService + # TODO: Just forward-declared everything due to circular 
dependencies. Fix structure. + events: "EventServiceBase" + latents: "LatentsStorageBase" + queue: "InvocationQueueABC" + model_manager: "ModelManager" + restoration: "RestorationServices" + configuration: "InvokeAISettings" + images: "ImageService" # NOTE: we must forward-declare any types that include invocations, since invocations can use services - graph_library: ItemStorageABC["LibraryGraph"] - graph_execution_manager: ItemStorageABC["GraphExecutionState"] + graph_library: "ItemStorageABC[LibraryGraph]" + graph_execution_manager: "ItemStorageABC[GraphExecutionState]" processor: "InvocationProcessorABC" def __init__( self, - model_manager: ModelManager, - events: EventServiceBase, - logger: Logger, - latents: LatentsStorageBase, - images: ImageService, - queue: InvocationQueueABC, - graph_library: ItemStorageABC["LibraryGraph"], - graph_execution_manager: ItemStorageABC["GraphExecutionState"], + model_manager: "ModelManager", + events: "EventServiceBase", + logger: "Logger", + latents: "LatentsStorageBase", + images: "ImageService", + queue: "InvocationQueueABC", + graph_library: "ItemStorageABC[LibraryGraph]", + graph_execution_manager: "ItemStorageABC[GraphExecutionState]", processor: "InvocationProcessorABC", - restoration: RestorationServices, - configuration: InvokeAISettings = None, + restoration: "RestorationServices", + configuration: "InvokeAISettings", ): self.model_manager = model_manager self.events = events diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index c1155ff73e..26e4929be2 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,6 +1,6 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Extra, Field, StrictStr from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -31,6 +31,8 @@ class ImageRecord(BaseModel): description="The deleted timestamp of the image." ) """The deleted timestamp of the image.""" + is_intermediate: bool = Field(description="Whether this is an intermediate image.") + """Whether this is an intermediate image.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -48,6 +50,25 @@ """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" +class ImageRecordChanges(BaseModel, extra=Extra.forbid): + """A set of changes to apply to an image record. + + Only limited changes are valid: + - `image_category`: change the category of an image + - `session_id`: change the session associated with an image + """ + + image_category: Optional[ImageCategory] = Field( + default=None, + description="The image's new category."
+ ) + """The image's new category.""" + session_id: Optional[StrictStr] = Field( + default=None, + description="The image's new session ID.", + ) + """The image's new session ID.""" + + class ImageUrlsDTO(BaseModel): """The URLs for an image and its thumbnail.""" @@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at = image_dict.get("created_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) + is_intermediate = image_dict.get("is_intermediate", False) raw_metadata = image_dict.get("metadata") @@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at=created_at, updated_at=updated_at, deleted_at=deleted_at, + is_intermediate=is_intermediate, ) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index cdd9db85de..9e3b5a0a30 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -1,10 +1,7 @@ import time import traceback from threading import Event, Thread, BoundedSemaphore -from typing import Any, TypeGuard -from invokeai.app.invocations.image import ImageOutput -from invokeai.app.models.image import ImageType from ..invocations.baseinvocation import InvocationContext from .invocation_queue import InvocationQueueItem from .invoker import InvocationProcessorABC, Invoker diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index 8f5b1a8395..6f2f33e6af 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta): def __init__(self, model_info: dict, params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), + **kwargs, ): self.model_info=model_info self.params=params + self.kwargs = kwargs def generate(self, prompt: str='', @@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta): model=model, scheduler_name=generator_args.get('scheduler') ) - uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) + + # get conditioning from prompt via Compel package + uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model) + gen_class = self._generator_class() - generator = gen_class(model, self.params.precision) + generator = gen_class(model, self.params.precision, **self.kwargs) if self.params.variation_amount > 0: generator.set_variation(generator_args.get('seed'), generator_args.get('variation_amount'), @@ -276,7 +281,7 @@ class Generator: precision: str model: DiffusionPipeline - def __init__(self, model: DiffusionPipeline, precision: str): + def __init__(self, model: DiffusionPipeline, precision: str, **kwargs): self.model = model self.precision = precision self.seed = None diff --git a/invokeai/backend/generator/txt2img.py b/invokeai/backend/generator/txt2img.py index e5a96212f0..a83a8e0c31 100644 --- a/invokeai/backend/generator/txt2img.py +++ b/invokeai/backend/generator/txt2img.py @@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator import PIL.Image import torch +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from ..stable_diffusion import ( ConditioningData, PostprocessingSettings, @@ -13,8 +17,13 @@ from .base import Generator class 
Txt2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) + def __init__(self, model, precision, + control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None, + **kwargs): + self.control_model = control_model + if isinstance(self.control_model, list): + self.control_model = MultiControlNetModel(self.control_model) + super().__init__(model, precision, **kwargs) @torch.no_grad() def get_make_image( @@ -42,9 +51,12 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin + control_image = kwargs.get("control_image", None) + do_classifier_free_guidance = cfg_scale > 1.0 # noinspection PyTypeChecker pipeline: StableDiffusionGeneratorPipeline = self.model + pipeline.control_model = self.control_model pipeline.scheduler = sampler uc, c, extra_conditioning_info = conditioning @@ -61,6 +73,37 @@ class Txt2Img(Generator): ), ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? + if control_image is not None: + if isinstance(self.control_model, ControlNetModel): + control_image = pipeline.prepare_control_image( + image=control_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=width, + height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=self.control_model.device, + dtype=self.control_model.dtype, + ) + elif isinstance(self.control_model, MultiControlNetModel): + images = [] + for image_ in control_image: + image_ = self.model.prepare_control_image( + image=image_, + do_classifier_free_guidance=do_classifier_free_guidance, + width=width, + height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=self.control_model.device, + dtype=self.control_model.dtype, + ) + images.append(image_) + control_image = images + kwargs["control_image"] = control_image + def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image: pipeline_output = pipeline.image_from_embeddings( latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), @@ -68,6 +111,7 @@ class Txt2Img(Generator): num_inference_steps=steps, conditioning_data=conditioning_data, callback=step_callback, + **kwargs, ) if ( diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 4ca2a5cb30..ec2902e4d6 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -2,23 +2,29 @@ from __future__ import annotations import dataclasses import inspect +import math import secrets from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel, Field import einops import PIL.Image +import numpy as np from accelerate.utils import set_seed import psutil import torch import torchvision.transforms as T from compel import EmbeddingsProvider from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) +from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( StableDiffusionImg2ImgPipeline, ) @@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( ) from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from diffusers.utils import PIL_INTERPOLATION from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize @@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): raise AssertionError("why was that an empty generator?") return result +@dataclass +class ControlNetData: + model: ControlNetModel = Field(default=None) + image_tensor: torch.Tensor = Field(default=None) + weight: float = Field(default=1.0) + begin_step_percent: float = Field(default=0.0) + end_step_percent: float = Field(default=1.0) @dataclass(frozen=True) class ConditioningData: @@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, precision: str = "float32", + control_model: Optional[ControlNetModel] = None, ): super().__init__( vae, @@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + # FIXME: can't currently register control module + # control_model=control_model, ) self.invokeai_diffuser = InvokeAIDiffuserComponent( self.unet, self._unet_forward, is_running_diffusers=True @@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group.install(*self._submodels) + self.control_model = control_model def _adjust_memory_efficient_attention(self, latents: torch.Tensor): """ @@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, callback: Callable[[PipelineIntermediateState], None] = None, run_id=None, + **kwargs, ) -> InvokeAIStableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. 
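The ControlNetData container above is the carrier for ControlNet conditioning through the denoising loop: weight scales the residuals added into the UNet, and begin_step_percent/end_step_percent bound the slice of the schedule during which that ControlNet runs. A minimal sketch of how a caller might assemble it, assuming a constructed pipeline (a StableDiffusionGeneratorPipeline), plus conditioning_data, noise, and control_image_tensor built elsewhere in the app; the checkpoint name is illustrative only:

```python
import torch
from diffusers.models.controlnet import ControlNetModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData

# Illustrative checkpoint name; any ControlNet trained for the same base model works.
control_model = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)

# control_image_tensor is assumed to come from pipeline.prepare_control_image(),
# so its batch dimension is already doubled for classifier-free guidance.
control_data = [
    ControlNetData(
        model=control_model,
        image_tensor=control_image_tensor,  # assumed to exist
        weight=0.8,                 # scales the residuals added to the UNet
        begin_step_percent=0.0,     # start applying at the first step...
        end_step_percent=0.75,      # ...stop after 75% of the schedule
    )
]

# latents_from_embeddings threads control_data down to step(), where each
# ControlNetData is only evaluated while step_index falls inside its range.
latents, attention_maps = pipeline.latents_from_embeddings(
    latents=torch.zeros_like(noise),      # pipeline, conditioning_data and
    num_inference_steps=30,               # noise are assumed to exist
    conditioning_data=conditioning_data,
    noise=noise,
    control_data=control_data,
)
```

Because prepare_control_image doubles the batch when classifier-free guidance is enabled, image_tensor already matches the doubled latent batch that step() builds below.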
@@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise=noise, run_id=run_id, callback=callback, + **kwargs, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, run_id=None, callback: Callable[[PipelineIntermediateState], None] = None, + control_data: List[ControlNetData] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device('cpu') @@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance=additional_guidance, run_id=run_id, callback=callback, + control_data=control_data, + **kwargs, ) return result.latents, result.attention_map_saver @@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, run_id: str = None, additional_guidance: List[Callable] = None, + control_data: List[ControlNetData] = None, + **kwargs, ): self._adjust_memory_efficient_attention(latents) if run_id is None: @@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): latents = self.scheduler.add_noise(latents, noise, batched_t) attention_map_saver: Optional[AttentionMapSaver] = None - + # print("timesteps:", timesteps) for i, t in enumerate(self.progress_bar(timesteps)): batched_t.fill_(t) step_output = self.step( @@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=i, total_step_count=len(timesteps), additional_guidance=additional_guidance, + control_data=control_data, + **kwargs, ) latents = step_output.prev_sample @@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index: int, total_step_count: int, additional_guidance: List[Callable] = None, + control_data: List[ControlNetData] = None, + **kwargs, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] - if additional_guidance is None: additional_guidance = [] @@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # i.e. before or after passing it to InvokeAIDiffuserComponent latent_model_input = self.scheduler.scale_model_input(latents, timestep) + # default is no controlnet, so set controlnet processing output to None + down_block_res_samples, mid_block_res_sample = None, None + + if control_data is not None: + if conditioning_data.guidance_scale > 1.0: + # expand the latents input to the control model when doing classifier-free guidance + # (currently this should always be true, as a conditional elsewhere stops execution when + # classifier-free guidance is <= 1.0)
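+                # the doubled batch must line up with the torch.cat([uncond, cond])
+                # embeddings passed to the control model as encoder_hidden_states below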
+ latent_control_input = torch.cat([latent_model_input] * 2) + else: + latent_control_input = latent_model_input + # control_data should be a List[ControlNetData] + # this loop covers both ControlNet (one ControlNetData in list) + # and MultiControlNet (multiple ControlNetData in list) + for i, control_datum in enumerate(control_data): + # print("controlnet", i, "==>", type(control_datum)) + first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) + last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) + # only apply controlnet if current step is within the controlnet's begin/end step range + if first_control_step <= step_index <= last_control_step: + # print("running controlnet", i, "for step", step_index) + down_samples, mid_sample = control_datum.model( + sample=latent_control_input, + timestep=timestep, + encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, + conditioning_data.text_embeddings]), + controlnet_cond=control_datum.image_tensor, + conditioning_scale=control_datum.weight, + # cross_attention_kwargs, + guess_mode=False, + return_dict=False, + ) + if down_block_res_samples is None and mid_block_res_sample is None: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + # add controlnet outputs together when multiple controlnets are active + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + # predict the noise residual noise_pred = self.invokeai_diffuser.do_diffusion_step( latent_model_input, @@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data.guidance_scale, step_index=step_index, total_step_count=total_step_count, + down_block_additional_residuals=down_block_res_samples, # from controlnet(s) + mid_block_additional_residual=mid_block_res_sample, # from controlnet(s) ) # compute the previous noisy sample x_t -> x_t-1 @@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): t, text_embeddings, cross_attention_kwargs: Optional[dict[str, Any]] = None, + **kwargs, ): """predict the noise residual""" if is_inpainting_model(self.unet) and latents.size(1) == 4: @@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # First three args should be positional, not keywords, so torch hooks can see them. 
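+        # ControlNet residuals (down_block_additional_residuals and
+        # mid_block_additional_residual) arrive here via **kwargs and are
+        # forwarded unchanged to the UNet call below.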
return self.unet( - latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs + latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs, + **kwargs, ).sample def img2img_from_embeddings( @@ -728,7 +803,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, run_id=None, callback=None, - ) -> InvokeAIStableDiffusionPipelineOutput: + ) -> InvokeAIStableDiffusionPipelineOutput: timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength) result_latents, result_attention_maps = self.latents_from_embeddings( latents=initial_latents if strength < 1.0 else torch.zeros_like( @@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): debug_image( img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True ) + + # Copied from diffusers pipeline_stable_diffusion_controlnet.py + # Returns torch.Tensor of shape (batch_size, 3, height, width) + def prepare_control_image( + self, + image, + # FIXME: need to fix hardwiring of width and height, change to base them on the latents' dimensions? + # latents, + width=512, # should be 8 * latents.shape[3] + height=512, # should be 8 * latents.shape[2] + batch_size=1, + num_images_per_prompt=1, + device="cuda", + dtype=torch.float16, + do_classifier_free_guidance=True, + ): + + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + image = images + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + return image diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 4131837b41..d05565c506 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent: unconditional_guidance_scale: float, step_index: Optional[int] = None, total_step_count: Optional[int] = None, + **kwargs, ): """ :param x: current latents @@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent: if wants_hybrid_conditioning: unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) elif wants_cross_attention_control: ( @@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) elif self.sequential_guidance: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) else: @@ 
-235,7 +237,7 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) combined_next_x = self._combine( @@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) if conditioned_next_x.device.type == "mps": @@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent: sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor, + **kwargs, ): # low-memory sequential path - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) + conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) if conditioned_next_x.device.type == "mps": # prevent a result filled with zeros. seems to be a torch bug. conditioned_next_x = conditioned_next_x.clone() return unconditioned_next_x, conditioned_next_x - def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): assert isinstance(conditioning, dict) assert isinstance(unconditioning, dict) x_twice = torch.cat([x] * 2) @@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent: else: both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ).chunk(2) return unconditioned_next_x, conditioned_next_x @@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): if self.is_running_diffusers: return self._apply_cross_attention_controlled_conditioning__diffusers( @@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) else: return self._apply_cross_attention_controlled_conditioning__compvis( @@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) def _apply_cross_attention_controlled_conditioning__diffusers( @@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): context: Context = self.cross_attention_control_context @@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent: sigma, unconditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) # do requested cross attention types for conditioning (positive prompt) @@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent: sigma, conditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) return unconditioned_next_x, 
conditioned_next_x @@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) @@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent: context: Context = self.cross_attention_control_context try: - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) # process x using the original prompt, saving the attention maps # print("saving attention maps for", cross_attention_control_types_to_do) for ca_type in cross_attention_control_types_to_do: context.request_save_attention_maps(ca_type) - _ = self.model_forward_callback(x, sigma, conditioning) + _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,) context.clear_requests(cleanup=False) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied @@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent: self.conditioning.cross_attention_control_args.edited_conditioning ) conditioned_next_x = self.model_forward_callback( - x, sigma, edited_conditioning + x, sigma, edited_conditioning, **kwargs, ) context.clear_requests(cleanup=True) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index f23e83a191..1fbc2f978c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit'; import type { RootState, AppDispatch } from '../../store'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; -import { addImageResultReceivedListener } from './listeners/invocationComplete'; -import { addImageUploadedListener } from './listeners/imageUploaded'; -import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; +import { + addImageUploadedFulfilledListener, + addImageUploadedRejectedListener, +} from './listeners/imageUploaded'; +import { + addImageDeletedFulfilledListener, + addImageDeletedPendingListener, + addImageDeletedRejectedListener, + addRequestedImageDeletionListener, +} from './listeners/imageDeleted'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; @@ -19,6 +26,47 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasMergedListener } from './listeners/canvasMerged'; +import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress'; +import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete'; +import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete'; +import { addInvocationErrorListener } from './listeners/socketio/invocationError'; +import { addInvocationStartedListener } from 
'./listeners/socketio/invocationStarted'; +import { addSocketConnectedListener } from './listeners/socketio/socketConnected'; +import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; +import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; +import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; +import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke'; +import { + addImageMetadataReceivedFulfilledListener, + addImageMetadataReceivedRejectedListener, +} from './listeners/imageMetadataReceived'; +import { + addImageUrlsReceivedFulfilledListener, + addImageUrlsReceivedRejectedListener, +} from './listeners/imageUrlsReceived'; +import { + addSessionCreatedFulfilledListener, + addSessionCreatedPendingListener, + addSessionCreatedRejectedListener, +} from './listeners/sessionCreated'; +import { + addSessionInvokedFulfilledListener, + addSessionInvokedPendingListener, + addSessionInvokedRejectedListener, +} from './listeners/sessionInvoked'; +import { + addSessionCanceledFulfilledListener, + addSessionCanceledPendingListener, + addSessionCanceledRejectedListener, +} from './listeners/sessionCanceled'; +import { + addReceivedResultImagesPageFulfilledListener, + addReceivedResultImagesPageRejectedListener, +} from './listeners/receivedResultImagesPage'; +import { + addReceivedUploadImagesPageFulfilledListener, + addReceivedUploadImagesPageRejectedListener, +} from './listeners/receivedUploadImagesPage'; export const listenerMiddleware = createListenerMiddleware(); @@ -38,17 +86,67 @@ export type AppListenerEffect = ListenerEffect< AppDispatch >; -addImageUploadedListener(); -addInitialImageSelectedListener(); -addImageResultReceivedListener(); -addRequestedImageDeletionListener(); +// Image uploaded +addImageUploadedFulfilledListener(); +addImageUploadedRejectedListener(); +addInitialImageSelectedListener(); + +// Image deleted +addRequestedImageDeletionListener(); +addImageDeletedPendingListener(); +addImageDeletedFulfilledListener(); +addImageDeletedRejectedListener(); + +// Image metadata +addImageMetadataReceivedFulfilledListener(); +addImageMetadataReceivedRejectedListener(); + +// Image URLs +addImageUrlsReceivedFulfilledListener(); +addImageUrlsReceivedRejectedListener(); + +// User Invoked addUserInvokedCanvasListener(); addUserInvokedNodesListener(); addUserInvokedTextToImageListener(); addUserInvokedImageToImageListener(); +addSessionReadyToInvokeListener(); +// Canvas actions addCanvasSavedToGalleryListener(); addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); + +// socketio +addGeneratorProgressListener(); +addGraphExecutionStateCompleteListener(); +addInvocationCompleteListener(); +addInvocationErrorListener(); +addInvocationStartedListener(); +addSocketConnectedListener(); +addSocketDisconnectedListener(); +addSocketSubscribedListener(); +addSocketUnsubscribedListener(); + +// Session Created +addSessionCreatedPendingListener(); +addSessionCreatedFulfilledListener(); +addSessionCreatedRejectedListener(); + +// Session Invoked +addSessionInvokedPendingListener(); +addSessionInvokedFulfilledListener(); +addSessionInvokedRejectedListener(); + +// Session Canceled +addSessionCanceledPendingListener(); +addSessionCanceledFulfilledListener(); +addSessionCanceledRejectedListener(); + +// Gallery pages +addReceivedResultImagesPageFulfilledListener(); +addReceivedResultImagesPageRejectedListener(); 
+addReceivedUploadImagesPageFulfilledListener(); +addReceivedUploadImagesPageRejectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index 1e2d99541c..fbc9c9c225 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -52,7 +52,6 @@ export const addCanvasMergedListener = () => { dispatch( imageUploaded({ - imageType: 'intermediates', formData: { file: new File([blob], filename, { type: 'image/png' }), }, @@ -65,7 +64,7 @@ export const addCanvasMergedListener = () => { action.meta.arg.formData.file.name === filename ); - const mergedCanvasImage = payload.response; + const mergedCanvasImage = payload; dispatch( setMergedCanvas({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index d8237d1d5c..2df3dacea2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -29,7 +29,6 @@ export const addCanvasSavedToGalleryListener = () => { dispatch( imageUploaded({ - imageType: 'results', formData: { file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 42a62b3d80..cd4771b96a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -4,9 +4,14 @@ import { imageDeleted } from 'services/thunks/image'; import { log } from 'app/logging/useLogger'; import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; +import { resultsAdapter } from 'features/gallery/store/resultsSlice'; const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); +/** + * Called when the user requests an image deletion + */ export const addRequestedImageDeletionListener = () => { startAppListening({ actionCreator: requestedImageDeletion, @@ -19,11 +24,6 @@ export const addRequestedImageDeletionListener = () => { const { image_name, image_type } = image; - if (image_type !== 'uploads' && image_type !== 'results') { - moduleLog.warn({ data: image }, `Invalid image type ${image_type}`); - return; - } - const selectedImageName = getState().gallery.selectedImage?.image_name; if (selectedImageName === image_name) { @@ -57,3 +57,49 @@ export const addRequestedImageDeletionListener = () => { }, }); }; + +/** + * Called when the actual delete request is sent to the server + */ +export const addImageDeletedPendingListener = () => { + startAppListening({ + actionCreator: imageDeleted.pending, + effect: (action, { dispatch, getState }) => { + const { imageName, imageType } = action.meta.arg; + // Preemptively remove the image from the gallery + if (imageType === 'uploads') { + 
uploadsAdapter.removeOne(getState().uploads, imageName); + } + if (imageType === 'results') { + resultsAdapter.removeOne(getState().results, imageName); + } + }, + }); +}; + +/** + * Called on successful delete + */ +export const addImageDeletedFulfilledListener = () => { + startAppListening({ + actionCreator: imageDeleted.fulfilled, + effect: (action, { dispatch, getState }) => { + moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted'); + }, + }); +}; + +/** + * Called on failed delete + */ +export const addImageDeletedRejectedListener = () => { + startAppListening({ + actionCreator: imageDeleted.rejected, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Unable to delete image' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts new file mode 100644 index 0000000000..c93ed2820f --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -0,0 +1,43 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { + ResultsImageDTO, + resultUpserted, +} from 'features/gallery/store/resultsSlice'; +import { + UploadsImageDTO, + uploadUpserted, +} from 'features/gallery/store/uploadsSlice'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageMetadataReceivedFulfilledListener = () => { + startAppListening({ + actionCreator: imageMetadataReceived.fulfilled, + effect: (action, { getState, dispatch }) => { + const image = action.payload; + moduleLog.debug({ data: { image } }, 'Image metadata received'); + + if (image.image_type === 'results') { + dispatch(resultUpserted(action.payload as ResultsImageDTO)); + } + + if (image.image_type === 'uploads') { + dispatch(uploadUpserted(action.payload as UploadsImageDTO)); + } + }, + }); +}; + +export const addImageMetadataReceivedRejectedListener = () => { + startAppListening({ + actionCreator: imageMetadataReceived.rejected, + effect: (action, { getState, dispatch }) => { + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Problem receiving image metadata' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 1d66166c12..5b177eae91 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,25 +1,31 @@ import { startAppListening } from '..'; -import { uploadAdded } from 'features/gallery/store/uploadsSlice'; +import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; import { initialImageSelected } from 'features/parameters/store/actions'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; -import { resultAdded } from 'features/gallery/store/resultsSlice'; +import { resultUpserted } from 'features/gallery/store/resultsSlice'; import { isResultsImageDTO, 
isUploadsImageDTO } from 'services/types/guards'; +import { log } from 'app/logging/useLogger'; -export const addImageUploadedListener = () => { +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUploadedFulfilledListener = () => { startAppListening({ predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> => imageUploaded.fulfilled.match(action) && - action.payload.response.image_type !== 'intermediates', + action.payload.is_intermediate === false, effect: (action, { dispatch, getState }) => { - const { response: image } = action.payload; + const image = action.payload; + + moduleLog.debug({ arg: '', image }, 'Image uploaded'); const state = getState(); + // Handle uploads if (isUploadsImageDTO(image)) { - dispatch(uploadAdded(image)); + dispatch(uploadUpserted(image)); dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); @@ -36,9 +42,26 @@ export const addImageUploadedListener = () => { } } + // Handle results + // TODO: Can this ever happen? I don't think so... if (isResultsImageDTO(image)) { - dispatch(resultAdded(image)); + dispatch(resultUpserted(image)); } }, }); }; + +export const addImageUploadedRejectedListener = () => { + startAppListening({ + actionCreator: imageUploaded.rejected, + effect: (action, { dispatch }) => { + dispatch( + addToast({ + title: 'Image Upload Failed', + description: action.error.message, + status: 'error', + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts new file mode 100644 index 0000000000..4ff2a02118 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts @@ -0,0 +1,51 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageUrlsReceived } from 'services/thunks/image'; +import { resultsAdapter } from 'features/gallery/store/resultsSlice'; +import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUrlsReceivedFulfilledListener = () => { + startAppListening({ + actionCreator: imageUrlsReceived.fulfilled, + effect: (action, { getState, dispatch }) => { + const image = action.payload; + moduleLog.debug({ data: { image } }, 'Image URLs received'); + + const { image_type, image_name, image_url, thumbnail_url } = image; + + if (image_type === 'results') { + resultsAdapter.updateOne(getState().results, { + id: image_name, + changes: { + image_url, + thumbnail_url, + }, + }); + } + + if (image_type === 'uploads') { + uploadsAdapter.updateOne(getState().uploads, { + id: image_name, + changes: { + image_url, + thumbnail_url, + }, + }); + } + }, + }); +}; + +export const addImageUrlsReceivedRejectedListener = () => { + startAppListening({ + actionCreator: imageUrlsReceived.rejected, + effect: (action, { getState, dispatch }) => { + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Problem getting image URLs' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts deleted file mode 100644 index 0222eea93c..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts +++ /dev/null @@ -1,62 +0,0 @@ -import { invocationComplete 
} from 'services/events/actions'; -import { isImageOutput } from 'services/types/guards'; -import { - imageMetadataReceived, - imageUrlsReceived, -} from 'services/thunks/image'; -import { startAppListening } from '..'; -import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; - -const nodeDenylist = ['dataURL_image']; - -export const addImageResultReceivedListener = () => { - startAppListening({ - predicate: (action) => { - if ( - invocationComplete.match(action) && - isImageOutput(action.payload.data.result) - ) { - return true; - } - return false; - }, - effect: async (action, { getState, dispatch, take }) => { - if (!invocationComplete.match(action)) { - return; - } - - const { data } = action.payload; - const { result, node, graph_execution_state_id } = data; - - if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { - const { image_name, image_type } = result.image; - - dispatch( - imageUrlsReceived({ imageName: image_name, imageType: image_type }) - ); - - dispatch( - imageMetadataReceived({ - imageName: image_name, - imageType: image_type, - }) - ); - - // Handle canvas image - if ( - graph_execution_state_id === - getState().canvas.layerState.stagingArea.sessionId - ) { - const [{ payload: image }] = await take( - ( - action - ): action is ReturnType<typeof imageMetadataReceived.fulfilled> => - imageMetadataReceived.fulfilled.match(action) && - action.payload.image_name === image_name - ); - dispatch(addImageToStagingArea(image)); - } - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts new file mode 100644 index 0000000000..bcdd11ef97 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts @@ -0,0 +1,33 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { receivedResultImagesPage } from 'services/thunks/gallery'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'gallery' }); + +export const addReceivedResultImagesPageFulfilledListener = () => { + startAppListening({ + actionCreator: receivedResultImagesPage.fulfilled, + effect: (action, { getState, dispatch }) => { + const page = action.payload; + moduleLog.debug( + { data: { page } }, + `Received ${page.items.length} results` + ); + }, + }); +}; + +export const addReceivedResultImagesPageRejectedListener = () => { + startAppListening({ + actionCreator: receivedResultImagesPage.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + moduleLog.debug( + { data: { error: serializeError(action.payload.error) } }, + 'Problem receiving results' + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts new file mode 100644 index 0000000000..68813aae27 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts @@ -0,0 +1,33 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { receivedUploadImagesPage } from 'services/thunks/gallery'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'gallery' }); + +export const 
addReceivedUploadImagesPageFulfilledListener = () => { + startAppListening({ + actionCreator: receivedUploadImagesPage.fulfilled, + effect: (action, { getState, dispatch }) => { + const page = action.payload; + moduleLog.debug( + { data: { page } }, + `Received ${page.items.length} uploads` + ); + }, + }); +}; + +export const addReceivedUploadImagesPageRejectedListener = () => { + startAppListening({ + actionCreator: receivedUploadImagesPage.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + moduleLog.debug( + { data: { error: serializeError(action.payload.error) } }, + 'Problem receiving uploads' + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts new file mode 100644 index 0000000000..6274ad4dc8 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts @@ -0,0 +1,48 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { sessionCanceled } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionCanceledPendingListener = () => { + startAppListening({ + actionCreator: sessionCanceled.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionCanceledFulfilledListener = () => { + startAppListening({ + actionCreator: sessionCanceled.fulfilled, + effect: (action, { getState, dispatch }) => { + const { sessionId } = action.meta.arg; + moduleLog.debug( + { data: { sessionId } }, + `Session canceled (${sessionId})` + ); + }, + }); +}; + +export const addSessionCanceledRejectedListener = () => { + startAppListening({ + actionCreator: sessionCanceled.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem canceling session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts new file mode 100644 index 0000000000..fb8a64d2e3 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -0,0 +1,45 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { sessionCreated } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionCreatedPendingListener = () => { + startAppListening({ + actionCreator: sessionCreated.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionCreatedFulfilledListener = () => { + startAppListening({ + actionCreator: sessionCreated.fulfilled, + effect: (action, { getState, dispatch }) => { + const session = action.payload; + moduleLog.debug({ data: { session } }, `Session created (${session.id})`); + }, + }); +}; + +export const addSessionCreatedRejectedListener = () => { + startAppListening({ + actionCreator: sessionCreated.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = 
action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem creating session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts new file mode 100644 index 0000000000..272d1d9e1d --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -0,0 +1,48 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { sessionInvoked } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionInvokedPendingListener = () => { + startAppListening({ + actionCreator: sessionInvoked.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionInvokedFulfilledListener = () => { + startAppListening({ + actionCreator: sessionInvoked.fulfilled, + effect: (action, { getState, dispatch }) => { + const { sessionId } = action.meta.arg; + moduleLog.debug( + { data: { sessionId } }, + `Session invoked (${sessionId})` + ); + }, + }); +}; + +export const addSessionInvokedRejectedListener = () => { + startAppListening({ + actionCreator: sessionInvoked.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem invoking session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts new file mode 100644 index 0000000000..8d4262e7da --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts @@ -0,0 +1,22 @@ +import { startAppListening } from '..'; +import { sessionInvoked } from 'services/thunks/session'; +import { log } from 'app/logging/useLogger'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionReadyToInvokeListener = () => { + startAppListening({ + actionCreator: sessionReadyToInvoke, + effect: (action, { getState, dispatch }) => { + const { sessionId } = getState().system; + if (sessionId) { + moduleLog.debug( + { sessionId }, + `Session ready to invoke (${sessionId})` + ); + dispatch(sessionInvoked({ sessionId })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts new file mode 100644 index 0000000000..341b5e46d3 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts @@ -0,0 +1,28 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { generatorProgress } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGeneratorProgressListener = () => { + startAppListening({ + actionCreator: generatorProgress, + effect: (action, { dispatch, getState }) => { + if ( + 
getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored generator progress for canceled session' + ); + return; + } + + moduleLog.trace( + action.payload, + `Generator progress (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts new file mode 100644 index 0000000000..a66a7fb547 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts @@ -0,0 +1,17 @@ +import { log } from 'app/logging/useLogger'; +import { graphExecutionStateComplete } from 'services/events/actions'; +import { startAppListening } from '../..'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGraphExecutionStateCompleteListener = () => { + startAppListening({ + actionCreator: graphExecutionStateComplete, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Session invocation complete (${action.payload.data.graph_execution_state_id})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts new file mode 100644 index 0000000000..95e6d831c0 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts @@ -0,0 +1,74 @@ +import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { invocationComplete } from 'services/events/actions'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { sessionCanceled } from 'services/thunks/session'; +import { isImageOutput } from 'services/types/guards'; +import { progressImageSet } from 'features/system/store/systemSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; + +const moduleLog = log.child({ namespace: 'socketio' }); +const nodeDenylist = ['dataURL_image']; + +export const addInvocationCompleteListener = () => { + startAppListening({ + actionCreator: invocationComplete, + effect: async (action, { dispatch, getState, take }) => { + moduleLog.debug( + action.payload, + `Invocation complete (${action.payload.data.node.type})` + ); + + const sessionId = action.payload.data.graph_execution_state_id; + + const { cancelType, isCancelScheduled } = getState().system; + + // Handle scheduled cancelation + if (cancelType === 'scheduled' && isCancelScheduled) { + dispatch(sessionCanceled({ sessionId })); + } + + const { data } = action.payload; + const { result, node, graph_execution_state_id } = data; + + // This complete event has an associated image output + if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { + const { image_name, image_type } = result.image; + + // Get its metadata + dispatch( + imageMetadataReceived({ + imageName: image_name, + imageType: image_type, + }) + ); + + const [{ payload: imageDTO }] = await take( + imageMetadataReceived.fulfilled.match + ); + + if (getState().gallery.shouldAutoSwitchToNewImages) { + dispatch(imageSelected(imageDTO)); + } + + // Handle canvas image + if ( + 
graph_execution_state_id === + getState().canvas.layerState.stagingArea.sessionId + ) { + const [{ payload: image }] = await take( + ( + action + ): action is ReturnType<typeof imageMetadataReceived.fulfilled> => + imageMetadataReceived.fulfilled.match(action) && + action.payload.image_name === image_name + ); + dispatch(addImageToStagingArea(image)); + } + + dispatch(progressImageSet(null)); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts new file mode 100644 index 0000000000..3a98af120a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { invocationError } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationErrorListener = () => { + startAppListening({ + actionCreator: invocationError, + effect: (action, { dispatch, getState }) => { + moduleLog.error( + action.payload, + `Invocation error (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts new file mode 100644 index 0000000000..f898c62b23 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts @@ -0,0 +1,28 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { invocationStarted } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationStartedListener = () => { + startAppListening({ + actionCreator: invocationStarted, + effect: (action, { dispatch, getState }) => { + if ( + getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored invocation started for canceled session' + ); + return; + } + + moduleLog.debug( + action.payload, + `Invocation started (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts new file mode 100644 index 0000000000..bc9ecbec1e --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -0,0 +1,43 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketConnected } from 'services/events/actions'; +import { + receivedResultImagesPage, + receivedUploadImagesPage, +} from 'services/thunks/gallery'; +import { receivedModels } from 'services/thunks/model'; +import { receivedOpenAPISchema } from 'services/thunks/schema'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketConnectedListener = () => { + startAppListening({ + actionCreator: socketConnected, + effect: (action, { dispatch, getState }) => { + const { timestamp } = action.payload; + + moduleLog.debug({ timestamp }, 'Connected'); + + const { results, uploads, models, nodes, config } = getState(); + + const 
{ disabledTabs } = config; + + // These thunks need to be dispatched in middleware; they cannot be handled in a reducer + if (!results.ids.length) { + dispatch(receivedResultImagesPage()); + } + + if (!uploads.ids.length) { + dispatch(receivedUploadImagesPage()); + } + + if (!models.ids.length) { + dispatch(receivedModels()); + } + + if (!nodes.schema && !disabledTabs.includes('nodes')) { + dispatch(receivedOpenAPISchema()); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts new file mode 100644 index 0000000000..131c3ba18f --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts @@ -0,0 +1,14 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketDisconnected } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketDisconnectedListener = () => { + startAppListening({ + actionCreator: socketDisconnected, + effect: (action, { dispatch, getState }) => { + moduleLog.debug(action.payload, 'Disconnected'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts new file mode 100644 index 0000000000..400f8a1689 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketSubscribed } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketSubscribedListener = () => { + startAppListening({ + actionCreator: socketSubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Subscribed (${action.payload.sessionId})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts new file mode 100644 index 0000000000..af15c55d42 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketUnsubscribed } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketUnsubscribedListener = () => { + startAppListening({ + actionCreator: socketUnsubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Unsubscribed (${action.payload.sessionId})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 2ebd3684e9..ae388b85cf 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -1,9
+1,9 @@ import { startAppListening } from '..'; -import { sessionCreated, sessionInvoked } from 'services/thunks/session'; +import { sessionCreated } from 'services/thunks/session'; import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { log } from 'app/logging/useLogger'; import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { imageUploaded } from 'services/thunks/image'; +import { imageUpdated, imageUploaded } from 'services/thunks/image'; import { v4 as uuidv4 } from 'uuid'; import { Graph } from 'services/api'; import { @@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); /** - * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. - * It is also responsible for uploading the base and mask layers to the server. + * This listener is responsible for invoking the canvas. This involves a number of steps: + * + * 1. Generate image blobs from the canvas layers + * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint) + * 3. Build the canvas graph + * 4. Create the session with the graph + * 5. Upload the init image if necessary + * 6. Upload the mask image if necessary + * 7. Update the init and mask images with the session ID + * 8. Initialize the staging area if not yet initialized + * 9. Dispatch the sessionReadyToInvoke action to invoke the session */ export const addUserInvokedCanvasListener = () => { startAppListening({ @@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => { const { rangeNode, iterateNode, baseNode, edges } = graphComponents; - // Upload the base layer, to be used as init image - const baseFilename = `${uuidv4()}.png`; - - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([baseBlob], baseFilename, { type: 'image/png' }), - }, - }) - ); - - if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { - const [{ payload: basePayload }] = await take( - (action): action is ReturnType<typeof imageUploaded.fulfilled> => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === baseFilename - ); - - const { image_name: baseName, image_type: baseType } = - basePayload.response; - - baseNode.image = { - image_name: baseName, - image_type: baseType, - }; - } - - // Upload the mask layer image - const maskFilename = `${uuidv4()}.png`; - - if (baseNode.type === 'inpaint') { - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([maskBlob], maskFilename, { type: 'image/png' }), - }, - }) - ); - - const [{ payload: maskPayload }] = await take( - (action): action is ReturnType<typeof imageUploaded.fulfilled> => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === maskFilename - ); - - const { image_name: maskName, image_type: maskType } = - maskPayload.response; - - baseNode.mask = { - image_name: maskName, - image_type: maskType, - }; - } - - // Assemble!
Note that this graph *does not have the init or mask image set yet!* const nodes: Graph['nodes'] = { [rangeNode.id]: rangeNode, [iterateNode.id]: iterateNode, @@ -136,15 +90,90 @@ export const addUserInvokedCanvasListener = () => { const graph = { nodes, edges }; dispatch(canvasGraphBuilt(graph)); - moduleLog({ data: graph }, 'Canvas graph built'); - // Actually create the session + moduleLog.debug({ data: graph }, 'Canvas graph built'); + + // If we are generating img2img or inpaint, we need to upload the init images + if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { + const baseFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([baseBlob], baseFilename, { type: 'image/png' }), + }, + isIntermediate: true, + }) + ); + + // Wait for the image to be uploaded + const [{ payload: baseImageDTO }] = await take( + (action): action is ReturnType<typeof imageUploaded.fulfilled> => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === baseFilename + ); + + // Update the base node with the image name and type + baseNode.image = { + image_name: baseImageDTO.image_name, + image_type: baseImageDTO.image_type, + }; + } + + // For inpaint, we also need to upload the mask layer + if (baseNode.type === 'inpaint') { + const maskFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([maskBlob], maskFilename, { type: 'image/png' }), + }, + isIntermediate: true, + }) + ); + + // Wait for the mask to be uploaded + const [{ payload: maskImageDTO }] = await take( + (action): action is ReturnType<typeof imageUploaded.fulfilled> => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === maskFilename + ); + + // Update the base node with the mask image name and type + baseNode.mask = { + image_name: maskImageDTO.image_name, + image_type: maskImageDTO.image_type, + }; + } + + // Create the session and wait for response dispatch(sessionCreated({ graph })); + const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match); + const sessionId = sessionCreatedAction.payload.id; - // Wait for the session to be invoked (this is just the HTTP request to start processing) - const [{ meta }] = await take(sessionInvoked.fulfilled.match); + // Associate the init image with the session, now that we have the session ID + if ( + (baseNode.type === 'img2img' || baseNode.type === 'inpaint') && + baseNode.image + ) { + dispatch( + imageUpdated({ + imageName: baseNode.image.image_name, + imageType: baseNode.image.image_type, + requestBody: { session_id: sessionId }, + }) + ); + } - const { sessionId } = meta.arg; + // Associate the mask image with the session, now that we have the session ID + if (baseNode.type === 'inpaint' && baseNode.mask) { + dispatch( + imageUpdated({ + imageName: baseNode.mask.image_name, + imageType: baseNode.mask.image_type, + requestBody: { session_id: sessionId }, + }) + ); + } if (!state.canvas.layerState.stagingArea.boundingBox) { dispatch( @@ -158,7 +187,11 @@ export const addUserInvokedCanvasListener = () => { ); } + // Flag the session with the canvas session ID dispatch(canvasSessionIdChanged(sessionId)); + + // We are ready to invoke the session!
+ dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts index e747aefa08..7dcbe8a41d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'img2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildImageToImageGraph(state); dispatch(imageToImageGraphBuilt(graph)); - moduleLog({ data: graph }, 'Image to Image graph built'); + moduleLog.debug({ data: graph }, 'Image to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts index 01e532d5ff..6fda3db0d6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts @@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra import { log } from 'app/logging/useLogger'; import { nodesGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'nodes', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildNodesGraph(state); dispatch(nodesGraphBuilt(graph)); - moduleLog({ data: graph }, 'Nodes graph built'); + moduleLog.debug({ data: graph }, 'Nodes graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts index e3eb5d0b38..6042d86cb7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 
'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'txt2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildTextToImageGraph(state); + dispatch(textToImageGraphBuilt(graph)); - moduleLog({ data: graph }, 'Text to Image graph built'); + + moduleLog.debug({ data: graph }, 'Text to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index b89615b2c0..4e9c154f3a 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,6 +16,7 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; +// import sessionReducer from 'features/system/store/sessionSlice'; import configReducer from 'features/system/store/configSlice'; import uiReducer from 'features/ui/store/uiSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; @@ -46,6 +47,7 @@ const allReducers = { ui: uiReducer, uploads: uploadsReducer, hotkeys: hotkeysReducer, + // session: sessionReducer, }; const rootReducer = combineReducers(allReducers); diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index db6b9ee517..628d44b6f1 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -68,7 +68,6 @@ const ImageUploader = (props: ImageUploaderProps) => { async (file: File) => { dispatch( imageUploaded({ - imageType: 'uploads', formData: { file }, activeTabName, }) diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts index c27833218b..bb3999d6d0 100644 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ b/invokeai/frontend/web/src/common/util/parseMetadata.ts @@ -1,5 +1,10 @@ import { forEach, size } from 'lodash-es'; -import { ImageField, LatentsField, ConditioningField } from 'services/api'; +import { + ImageField, + LatentsField, + ConditioningField, + ControlField, +} from 'services/api'; const OBJECT_TYPESTRING = '[object Object]'; const STRING_TYPESTRING = '[object String]'; @@ -98,6 +103,24 @@ const parseConditioningField = ( }; }; +const parseControlField = (controlField: unknown): ControlField | undefined => { + // Must be an object + if (!isObject(controlField)) { + return; + } + + // A ControlField must have a `control` + if (!('control' in controlField)) { + return; + } + // console.log(typeof controlField.control); + + // Build a valid ControlField + return { + control: controlField.control, + }; +}; + type NodeMetadata = { 
[key: string]: | string @@ -105,7 +128,8 @@ type NodeMetadata = { | boolean | ImageField | LatentsField - | ConditioningField; + | ConditioningField + | ControlField; }; type InvokeAIMetadata = { @@ -131,7 +155,7 @@ export const parseNodeMetadata = ( return; } - // the only valid object types are ImageField, LatentsField and ConditioningField + // the only valid object types are ImageField, LatentsField, ConditioningField, ControlField if (isObject(nodeItem)) { if ('image_name' in nodeItem || 'image_type' in nodeItem) { const imageField = parseImageField(nodeItem); @@ -156,6 +180,14 @@ export const parseNodeMetadata = ( } return; } + + if ('control' in nodeItem) { + const controlField = parseControlField(nodeItem); + if (controlField) { + parsed[nodeKey] = controlField; + } + return; + } } // otherwise we accept any string, number or boolean diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index f5265b54db..c19a404a37 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -109,8 +109,9 @@ const currentImageButtonsSelector = createSelector( isLightboxOpen, shouldHidePreview, image: selectedImage, - seed: selectedImage?.metadata?.invokeai?.node?.seed, - prompt: selectedImage?.metadata?.invokeai?.node?.prompt, + seed: selectedImage?.metadata?.seed, + prompt: selectedImage?.metadata?.positive_conditioning, + negativePrompt: selectedImage?.metadata?.negative_conditioning, }; }, { @@ -245,13 +246,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { ); const handleUseSeed = useCallback(() => { - recallSeed(image?.metadata?.invokeai?.node?.seed); + recallSeed(image?.metadata?.seed); }, [image, recallSeed]); useHotkeys('s', handleUseSeed, [image]); const handleUsePrompt = useCallback(() => { - recallPrompt(image?.metadata?.invokeai?.node?.prompt); + recallPrompt( + image?.metadata?.positive_conditioning, + image?.metadata?.negative_conditioning + ); }, [image, recallPrompt]); useHotkeys('p', handleUsePrompt, [image]); @@ -454,7 +458,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { {t('parameters.copyImageToLink')} - + } size="sm" w="100%"> {t('parameters.downloadImage')} @@ -500,7 +504,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { icon={} tooltip={`${t('parameters.usePrompt')} (P)`} aria-label={`${t('parameters.usePrompt')} (P)`} - isDisabled={!image?.metadata?.invokeai?.node?.prompt} + isDisabled={!image?.metadata?.positive_conditioning} onClick={handleUsePrompt} /> @@ -508,7 +512,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { icon={} tooltip={`${t('parameters.useSeed')} (S)`} aria-label={`${t('parameters.useSeed')} (S)`} - isDisabled={!image?.metadata?.invokeai?.node?.seed} + isDisabled={!image?.metadata?.seed} onClick={handleUseSeed} /> @@ -517,9 +521,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { tooltip={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`} isDisabled={ - !['txt2img', 'img2img', 'inpaint'].includes( - String(image?.metadata?.invokeai?.node?.type) - ) + // not sure what this list should be + !['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type)) } onClick={handleClickUseAllParameters} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx 
b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx index 04fecac463..ed427f4984 100644 --- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx @@ -155,7 +155,10 @@ const HoverableImage = memo((props: HoverableImageProps) => { // Recall parameters handlers const handleRecallPrompt = useCallback(() => { - recallPrompt(image.metadata?.positive_conditioning); + recallPrompt( + image.metadata?.positive_conditioning, + image.metadata?.negative_conditioning + ); }, [image, recallPrompt]); const handleRecallSeed = useCallback(() => { @@ -248,7 +251,8 @@ const HoverableImage = memo((props: HoverableImageProps) => { icon={} onClickCapture={handleUseAllParameters} isDisabled={ - !['txt2img', 'img2img', 'inpaint'].includes( + // what should these be + !['t2l', 'l2l', 'inpaint'].includes( String(image?.metadata?.type) ) } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx index b4bf9a6d25..b01191105e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx @@ -8,29 +8,20 @@ import { Text, Tooltip, } from '@chakra-ui/react'; -import * as InvokeAI from 'app/types/invokeai'; import { useAppDispatch } from 'app/store/storeHooks'; import { useGetUrl } from 'common/util/getUrl'; import promptToString from 'common/util/promptToString'; -import { seedWeightsToString } from 'common/util/seedWeightPairs'; -import useSetBothPrompts from 'features/parameters/hooks/usePrompt'; import { setCfgScale, setHeight, setImg2imgStrength, setNegativePrompt, - setPerlin, setPositivePrompt, setScheduler, - setSeamless, setSeed, - setSeedWeights, - setShouldFitToWidthHeight, setSteps, - setThreshold, setWidth, } from 'features/parameters/store/generationSlice'; -import { setHiresFix } from 'features/parameters/store/postprocessingSlice'; import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; import { memo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; @@ -39,7 +30,6 @@ import { FaCopy } from 'react-icons/fa'; import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { ImageDTO } from 'services/api'; -import { filter } from 'lodash-es'; import { Scheduler } from 'app/constants'; type MetadataItemProps = { @@ -126,8 +116,6 @@ const memoEqualityCheck = ( const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const dispatch = useAppDispatch(); - const setBothPrompts = useSetBothPrompts(); - useHotkeys('esc', () => { dispatch(setShouldShowImageDetails(false)); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index d0d25f8bc6..fcf8359187 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -33,7 +33,7 @@ export const nextPrevImageButtonsSelector = createSelector( } const currentImageIndex = state[currentCategory].ids.findIndex( - (i) => i === selectedImage.name + (i) => i === selectedImage.image_name ); const 
nextImageIndex = clamp( diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts index 125f4ff5d5..36f4c49401 100644 --- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts @@ -1,14 +1,13 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { + PayloadAction, + createEntityAdapter, + createSlice, +} from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { receivedResultImagesPage, IMAGES_PER_PAGE, } from 'services/thunks/gallery'; -import { - imageDeleted, - imageMetadataReceived, - imageUrlsReceived, -} from 'services/thunks/image'; import { ImageDTO } from 'services/api'; import { dateComparator } from 'common/util/dateComparator'; @@ -26,6 +25,7 @@ type AdditionalResultsState = { pages: number; isLoading: boolean; nextPage: number; + upsertedImageCount: number; }; export const initialResultsState = @@ -34,6 +34,7 @@ export const initialResultsState = pages: 0, isLoading: false, nextPage: 0, + upsertedImageCount: 0, }); export type ResultsState = typeof initialResultsState; @@ -42,7 +43,10 @@ const resultsSlice = createSlice({ name: 'results', initialState: initialResultsState, reducers: { - resultAdded: resultsAdapter.upsertOne, + resultUpserted: (state, action: PayloadAction) => { + resultsAdapter.upsertOne(state, action.payload); + state.upsertedImageCount += 1; + }, }, extraReducers: (builder) => { /** @@ -68,47 +72,6 @@ const resultsSlice = createSlice({ state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1; state.isLoading = false; }); - - /** - * Image Metadata Received - FULFILLED - */ - builder.addCase(imageMetadataReceived.fulfilled, (state, action) => { - const { image_type } = action.payload; - - if (image_type === 'results') { - resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO); - } - }); - - /** - * Image URLs Received - FULFILLED - */ - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_type, image_url, thumbnail_url } = - action.payload; - - if (image_type === 'results') { - resultsAdapter.updateOne(state, { - id: image_name, - changes: { - image_url: image_url, - thumbnail_url: thumbnail_url, - }, - }); - } - }); - - /** - * Delete Image - PENDING - * Pre-emptively remove the image from the gallery - */ - builder.addCase(imageDeleted.pending, (state, action) => { - const { imageType, imageName } = action.meta.arg; - - if (imageType === 'results') { - resultsAdapter.removeOne(state, imageName); - } - }); }, }); @@ -120,6 +83,6 @@ export const { selectTotal: selectResultsTotal, } = resultsAdapter.getSelectors((state) => state.results); -export const { resultAdded } = resultsSlice.actions; +export const { resultUpserted } = resultsSlice.actions; export default resultsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts index 5e458503ec..3058e82673 100644 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts @@ -1,11 +1,14 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { + PayloadAction, + createEntityAdapter, + createSlice, +} from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { receivedUploadImagesPage, IMAGES_PER_PAGE, } from 
'services/thunks/gallery'; -import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; import { ImageDTO } from 'services/api'; import { dateComparator } from 'common/util/dateComparator'; @@ -23,6 +26,7 @@ type AdditionalUploadsState = { pages: number; isLoading: boolean; nextPage: number; + upsertedImageCount: number; }; export const initialUploadsState = @@ -31,6 +35,7 @@ export const initialUploadsState = pages: 0, nextPage: 0, isLoading: false, + upsertedImageCount: 0, }); export type UploadsState = typeof initialUploadsState; @@ -39,7 +44,10 @@ const uploadsSlice = createSlice({ name: 'uploads', initialState: initialUploadsState, reducers: { - uploadAdded: uploadsAdapter.upsertOne, + uploadUpserted: (state, action: PayloadAction) => { + uploadsAdapter.upsertOne(state, action.payload); + state.upsertedImageCount += 1; + }, }, extraReducers: (builder) => { /** @@ -65,36 +73,6 @@ const uploadsSlice = createSlice({ state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1; state.isLoading = false; }); - - /** - * Image URLs Received - FULFILLED - */ - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_type, image_url, thumbnail_url } = - action.payload; - - if (image_type === 'uploads') { - uploadsAdapter.updateOne(state, { - id: image_name, - changes: { - image_url: image_url, - thumbnail_url: thumbnail_url, - }, - }); - } - }); - - /** - * Delete Image - pending - * Pre-emptively remove the image from the gallery - */ - builder.addCase(imageDeleted.pending, (state, action) => { - const { imageType, imageName } = action.meta.arg; - - if (imageType === 'uploads') { - uploadsAdapter.removeOne(state, imageName); - } - }); }, }); @@ -106,6 +84,6 @@ export const { selectTotal: selectUploadsTotal, } = uploadsAdapter.getSelectors((state) => state.uploads); -export const { uploadAdded } = uploadsSlice.actions; +export const { uploadUpserted } = uploadsSlice.actions; export default uploadsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 9527708c40..346261fbff 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -7,6 +7,7 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; +import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -97,6 +98,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'control' && template.type === 'control') { + return ( + + ); + } + if (type === 'model' && template.type === 'model') { return ( +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(ControlInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 7e4dadc21d..a9ae209178 100644 --- 
a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500; export const FIELD_TYPE_MAP: Record = { integer: 'integer', + float: 'float', number: 'float', string: 'string', boolean: 'boolean', @@ -15,6 +16,8 @@ export const FIELD_TYPE_MAP: Record = { array: 'array', item: 'item', ColorField: 'color', + ControlField: 'control', + control: 'control', }; const COLOR_TOKEN_VALUE = 500; @@ -22,6 +25,9 @@ const COLOR_TOKEN_VALUE = 500; const getColorTokenCssVariable = (color: string) => `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`; +// @ts-ignore +// @ts-ignore +// @ts-ignore export const FIELDS: Record = { integer: { color: 'red', @@ -71,6 +77,12 @@ export const FIELDS: Record = { title: 'Conditioning', description: 'Conditioning may be passed between nodes.', }, + control: { + color: 'cyan', + colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left + title: 'Control', + description: 'Control info passed between nodes.', + }, model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index efb4a5518d..745584f244 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -61,6 +61,7 @@ export type FieldType = | 'image' | 'latents' | 'conditioning' + | 'control' | 'model' | 'array' | 'item' @@ -82,6 +83,7 @@ export type InputFieldValue = | ImageInputFieldValue | LatentsInputFieldValue | ConditioningInputFieldValue + | ControlInputFieldValue | EnumInputFieldValue | ModelInputFieldValue | ArrayInputFieldValue @@ -102,6 +104,7 @@ export type InputFieldTemplate = | ImageInputFieldTemplate | LatentsInputFieldTemplate | ConditioningInputFieldTemplate + | ControlInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate | ArrayInputFieldTemplate @@ -177,6 +180,11 @@ export type LatentsInputFieldValue = FieldValueBase & { export type ConditioningInputFieldValue = FieldValueBase & { type: 'conditioning'; + value?: string; +}; + +export type ControlInputFieldValue = FieldValueBase & { + type: 'control'; value?: undefined; }; @@ -262,6 +270,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { type: 'conditioning'; }; +export type ControlInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'control'; +}; + export type EnumInputFieldTemplate = InputFieldTemplateBase & { default: string | number; type: 'enum'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 11f0087488..e1f65e8826 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -10,6 +10,7 @@ import { IntegerInputFieldTemplate, LatentsInputFieldTemplate, ConditioningInputFieldTemplate, + ControlInputFieldTemplate, StringInputFieldTemplate, ModelInputFieldTemplate, ArrayInputFieldTemplate, @@ -215,6 +216,21 @@ const buildConditioningInputFieldTemplate = ({ return template; }; +const buildControlInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ControlInputFieldTemplate => { + const template: ControlInputFieldTemplate = { + ...baseField, + type: 'control', + inputRequirement: 
'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildEnumInputFieldTemplate = ({ schemaObject, baseField, @@ -286,9 +302,20 @@ export const getFieldType = ( if (typeHints && name in typeHints) { rawFieldType = typeHints[name]; } else if (!schemaObject.type) { - rawFieldType = refObjectToFieldType( - schemaObject.allOf![0] as OpenAPIV3.ReferenceObject - ); + // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf + if (schemaObject.allOf) { + rawFieldType = refObjectToFieldType( + schemaObject.allOf![0] as OpenAPIV3.ReferenceObject + ); + } else if (schemaObject.anyOf) { + rawFieldType = refObjectToFieldType( + schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject + ); + } else if (schemaObject.oneOf) { + rawFieldType = refObjectToFieldType( + schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject + ); + } } else if (schemaObject.enum) { rawFieldType = 'enum'; } else if (schemaObject.type) { @@ -331,6 +358,9 @@ export const buildInputFieldTemplate = ( if (['conditioning'].includes(fieldType)) { return buildConditioningInputFieldTemplate({ schemaObject, baseField }); } + if (['control'].includes(fieldType)) { + return buildControlInputFieldTemplate({ schemaObject, baseField }); + } if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 9221e5f7ac..0b10a3e464 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -52,6 +52,10 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'control') { + fieldValue.value = undefined; + } + if (template.type === 'model') { fieldValue.value = undefined; } diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts index cbe16abe28..51f89e8f74 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts @@ -11,7 +11,7 @@ import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes'; const POSITIVE_CONDITIONING = 'positive_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning'; const TEXT_TO_LATENTS = 'text_to_latents'; -const LATENTS_TO_IMAGE = 'latnets_to_image'; +const LATENTS_TO_IMAGE = 'latents_to_image'; /** * Builds the Text to Image tab graph. 
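The getFieldType change above makes the schema parser tolerate fields that carry no explicit `type` and instead wrap a `$ref` in `allOf`, `anyOf`, or `oneOf` — the new `ControlField` inputs are one such case. Here is a minimal sketch of that resolution order, assuming simplified stand-ins for `FIELD_TYPE_MAP` and the `refObjectToFieldType` helper (the real versions live in `constants.ts` and `fieldTemplateBuilders.ts`):

```ts
import type { OpenAPIV3 } from 'openapi-types';

// Simplified subset of FIELD_TYPE_MAP for illustration
const FIELD_TYPE_MAP: Record<string, string> = {
  integer: 'integer',
  number: 'float',
  string: 'string',
  boolean: 'boolean',
  ControlField: 'control',
};

// Hypothetical helper: derive a raw field type from a $ref, e.g.
// '#/components/schemas/ControlField' -> 'ControlField'
const refObjectToFieldType = (ref: OpenAPIV3.ReferenceObject): string =>
  ref.$ref.split('/').pop() ?? '';

const getFieldTypeSketch = (
  schemaObject: OpenAPIV3.SchemaObject,
  name: string,
  typeHints?: Record<string, string>
): string | undefined => {
  // 1. An explicit type hint always wins
  if (typeHints && name in typeHints) {
    return typeHints[name];
  }

  // 2. No `type` on the schema: it should wrap a $ref in allOf, anyOf or oneOf
  if (!schemaObject.type) {
    const wrapped =
      schemaObject.allOf ?? schemaObject.anyOf ?? schemaObject.oneOf;
    if (wrapped) {
      return FIELD_TYPE_MAP[
        refObjectToFieldType(wrapped[0] as OpenAPIV3.ReferenceObject)
      ];
    }
    return undefined;
  }

  // 3. Enums, then plain primitive types
  if (schemaObject.enum) {
    return 'enum';
  }
  return FIELD_TYPE_MAP[schemaObject.type];
};
```

The ordering matters: type hints override everything, and the `$ref` wrappers are only consulted when the schema object has no `type` of its own — before this change, only `allOf` was checked, so `anyOf`/`oneOf` schemas crashed the parser.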
diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index ddd19b8749..631552414d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -13,7 +13,9 @@ import { buildOutputFieldTemplates, } from './fieldTemplateBuilders'; -const invocationDenylist = ['Graph']; +const RESERVED_FIELD_NAMES = ['id', 'type', 'meta']; + +const invocationDenylist = ['Graph', 'InvocationMeta']; export const parseSchema = (openAPI: OpenAPIV3.Document) => { // filter out non-invocation schemas, plus some tricky invocations for now @@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { (inputsAccumulator, property, propertyName) => { if ( // `type` and `id` are not valid inputs/outputs - !['type', 'id'].includes(propertyName) && + !RESERVED_FIELD_NAMES.includes(propertyName) && isSchemaObject(property) ) { const field: InputFieldTemplate | undefined = diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts index ad9985b5de..27ae63e5dd 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts @@ -21,8 +21,8 @@ export const useParameters = () => { * Sets prompt with toast */ const recallPrompt = useCallback( - (prompt: unknown) => { - if (!isString(prompt)) { + (prompt: unknown, negativePrompt?: unknown) => { + if (!isString(prompt) || !isString(negativePrompt)) { toaster({ title: t('toast.promptNotSet'), description: t('toast.promptNotSetDesc'), @@ -33,7 +33,7 @@ export const useParameters = () => { return; } - setBothPrompts(prompt); + setBothPrompts(prompt, negativePrompt); toaster({ title: t('toast.promptSet'), status: 'info', @@ -112,12 +112,13 @@ export const useParameters = () => { const recallAllParameters = useCallback( (image: ImageDTO | undefined) => { const type = image?.metadata?.type; - if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) { + // not sure what this list should be + if (['t2l', 'l2l', 'inpaint'].includes(String(type))) { dispatch(allParametersSet(image)); - if (image?.metadata?.type === 'img2img') { + if (image?.metadata?.type === 'l2l') { dispatch(setActiveTab('img2img')); - } else if (image?.metadata?.type === 'txt2img') { + } else if (image?.metadata?.type === 't2l') { dispatch(setActiveTab('txt2img')); } diff --git a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts b/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts index 2a6a832720..3fee0bcdd8 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts @@ -12,15 +12,8 @@ const useSetBothPrompts = () => { const dispatch = useAppDispatch(); return useCallback( - (inputPrompt: InvokeAI.Prompt) => { - const promptString = - typeof inputPrompt === 'string' - ? 
inputPrompt - : promptToString(inputPrompt); - - const [prompt, negativePrompt] = getPromptAndNegative(promptString); - - dispatch(setPositivePrompt(prompt)); + (inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => { + dispatch(setPositivePrompt(inputPrompt)); dispatch(setNegativePrompt(negativePrompt)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index f5054f1969..849f848ff3 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -52,7 +52,7 @@ export const initialGenerationState: GenerationState = { perlin: 0, positivePrompt: '', negativePrompt: '', - scheduler: 'lms', + scheduler: 'euler', seamBlur: 16, seamSize: 96, seamSteps: 30, diff --git a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts b/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts index d6d1af0f8e..8f06c7d0ef 100644 --- a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts +++ b/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts @@ -7,19 +7,29 @@ export const setAllParametersReducer = ( state: Draft, action: PayloadAction ) => { - const node = action.payload?.metadata.invokeai?.node; + const metadata = action.payload?.metadata; - if (!node) { + if (!metadata) { return; } + // not sure what this list should be if ( - node.type === 'txt2img' || - node.type === 'img2img' || - node.type === 'inpaint' + metadata.type === 't2l' || + metadata.type === 'l2l' || + metadata.type === 'inpaint' ) { - const { cfg_scale, height, model, prompt, scheduler, seed, steps, width } = - node; + const { + cfg_scale, + height, + model, + positive_conditioning, + negative_conditioning, + scheduler, + seed, + steps, + width, + } = metadata; if (cfg_scale !== undefined) { state.cfgScale = Number(cfg_scale); @@ -30,8 +40,11 @@ export const setAllParametersReducer = ( if (model !== undefined) { state.model = String(model); } - if (prompt !== undefined) { - state.positivePrompt = String(prompt); + if (positive_conditioning !== undefined) { + state.positivePrompt = String(positive_conditioning); + } + if (negative_conditioning !== undefined) { + state.negativePrompt = String(negative_conditioning); } if (scheduler !== undefined) { const schedulerString = String(scheduler); @@ -51,8 +64,8 @@ export const setAllParametersReducer = ( } } - if (node.type === 'img2img') { - const { fit, image } = node as ImageToImageInvocation; + if (metadata.type === 'l2l') { + const { fit, image } = metadata as ImageToImageInvocation; if (fit !== undefined) { state.shouldFitToWidthHeight = Boolean(fit); diff --git a/invokeai/frontend/web/src/features/system/store/actions.ts b/invokeai/frontend/web/src/features/system/store/actions.ts new file mode 100644 index 0000000000..66181bc803 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/actions.ts @@ -0,0 +1,3 @@ +import { createAction } from '@reduxjs/toolkit'; + +export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke'); diff --git a/invokeai/frontend/web/src/features/system/store/sessionSlice.ts b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts new file mode 100644 index 0000000000..40d59c7baa --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts @@ -0,0 +1,62 @@ +// TODO: split system 
slice into this + +// import type { PayloadAction } from '@reduxjs/toolkit'; +// import { createSlice } from '@reduxjs/toolkit'; +// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions'; + +// export type SessionState = { +// /** +// * The current socket session id +// */ +// sessionId: string; +// /** +// * Whether the current session is a canvas session. Needed to manage the staging area. +// */ +// isCanvasSession: boolean; +// /** +// * When a session is canceled, its ID is stored here until a new session is created. +// */ +// canceledSessionId: string; +// }; + +// export const initialSessionState: SessionState = { +// sessionId: '', +// isCanvasSession: false, +// canceledSessionId: '', +// }; + +// export const sessionSlice = createSlice({ +// name: 'session', +// initialState: initialSessionState, +// reducers: { +// sessionIdChanged: (state, action: PayloadAction<string>) => { +// state.sessionId = action.payload; +// }, +// isCanvasSessionChanged: (state, action: PayloadAction<boolean>) => { +// state.isCanvasSession = action.payload; +// }, +// }, +// extraReducers: (builder) => { +// /** +// * Socket Subscribed +// */ +// builder.addCase(socketSubscribed, (state, action) => { +// state.sessionId = action.payload.sessionId; +// state.canceledSessionId = ''; +// }); + +// /** +// * Socket Unsubscribed +// */ +// builder.addCase(socketUnsubscribed, (state) => { +// state.sessionId = ''; +// }); +// }, +// }); + +// export const { sessionIdChanged, isCanvasSessionChanged } = +// sessionSlice.actions; + +// export default sessionSlice.reducer; + +export default {}; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 7331fcdba9..403fd60501 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,5 +1,5 @@ import { UseToastOptions } from '@chakra-ui/react'; -import type { PayloadAction } from '@reduxjs/toolkit'; +import { PayloadAction, isAnyOf } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; import { @@ -16,7 +16,11 @@ import { import { ProgressImage } from 'services/events/types'; import { makeToast } from '../../../app/components/Toaster'; -import { sessionCanceled, sessionInvoked } from 'services/thunks/session'; +import { + sessionCanceled, + sessionCreated, + sessionInvoked, +} from 'services/thunks/session'; import { receivedModels } from 'services/thunks/model'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; import { LogLevelName } from 'roarr'; @@ -215,6 +219,9 @@ export const systemSlice = createSlice({ languageChanged: (state, action: PayloadAction) => { state.language = action.payload; }, + progressImageSet(state, action: PayloadAction<ProgressImage | null>) { + state.progressImage = action.payload; + }, }, extraReducers(builder) { /** @@ -305,7 +312,6 @@ export const systemSlice = createSlice({ state.currentStep = 0; state.totalSteps = 0; state.statusTranslationKey = 'common.statusProcessingComplete'; - state.progressImage = null; if (state.canceledSession === data.graph_execution_state_id) { state.isProcessing = false; @@ -343,15 +349,8 @@ export const systemSlice = createSlice({ state.statusTranslationKey = 'common.statusPreparing'; }); - builder.addCase(sessionInvoked.rejected, (state, action) => { - const error = action.payload as string | undefined; - state.toastQueue.push( - makeToast({ title:
error || t('toast.serverError'), status: 'error' }) - ); - }); - /** - * Session Canceled + * Session Canceled - FULFILLED */ builder.addCase(sessionCanceled.fulfilled, (state, action) => { state.canceledSession = action.meta.arg.sessionId; @@ -414,6 +413,26 @@ export const systemSlice = createSlice({ builder.addCase(imageUploaded.fulfilled, (state) => { state.isUploading = false; }); + + // *** Matchers - must be after all cases *** + + /** + * Session Invoked - REJECTED + * Session Created - REJECTED + */ + builder.addMatcher(isAnySessionRejected, (state, action) => { + state.isProcessing = false; + state.isCancelable = false; + state.isCancelScheduled = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusConnected'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ title: t('toast.serverError'), status: 'error' }) + ); + }); }, }); @@ -438,6 +457,12 @@ export const { isPersistedChanged, shouldAntialiasProgressImageChanged, languageChanged, + progressImageSet, } = systemSlice.actions; export default systemSlice.reducer; + +const isAnySessionRejected = isAnyOf( + sessionCreated.rejected, + sessionInvoked.rejected +); diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index ecf8621ed6..e75aeac6cb 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -7,7 +7,6 @@ export { OpenAPI } from './core/OpenAPI'; export type { OpenAPIConfig } from './core/OpenAPI'; export type { AddInvocation } from './models/AddInvocation'; -export type { BlurInvocation } from './models/BlurInvocation'; export type { Body_upload_image } from './models/Body_upload_image'; export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CollectInvocation } from './models/CollectInvocation'; @@ -17,7 +16,6 @@ export type { CompelInvocation } from './models/CompelInvocation'; export type { CompelOutput } from './models/CompelOutput'; export type { ConditioningField } from './models/ConditioningField'; export type { CreateModelRequest } from './models/CreateModelRequest'; -export type { CropImageInvocation } from './models/CropImageInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DivideInvocation } from './models/DivideInvocation'; @@ -28,11 +26,20 @@ export type { GraphExecutionState } from './models/GraphExecutionState'; export type { GraphInvocation } from './models/GraphInvocation'; export type { GraphInvocationOutput } from './models/GraphInvocationOutput'; export type { HTTPValidationError } from './models/HTTPValidationError'; +export type { ImageBlurInvocation } from './models/ImageBlurInvocation'; export type { ImageCategory } from './models/ImageCategory'; +export type { ImageChannelInvocation } from './models/ImageChannelInvocation'; +export type { ImageConvertInvocation } from './models/ImageConvertInvocation'; +export type { ImageCropInvocation } from './models/ImageCropInvocation'; export type { ImageDTO } from './models/ImageDTO'; export type { ImageField } from './models/ImageField'; +export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation'; +export type { ImageLerpInvocation } from './models/ImageLerpInvocation'; export type { ImageMetadata } from './models/ImageMetadata'; +export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation'; 
export type { ImageOutput } from './models/ImageOutput'; +export type { ImagePasteInvocation } from './models/ImagePasteInvocation'; +export type { ImageRecordChanges } from './models/ImageRecordChanges'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; export type { ImageType } from './models/ImageType'; @@ -43,14 +50,12 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation'; export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntOutput } from './models/IntOutput'; -export type { InverseLerpInvocation } from './models/InverseLerpInvocation'; export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { LatentsField } from './models/LatentsField'; export type { LatentsOutput } from './models/LatentsOutput'; export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation'; export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation'; -export type { LerpInvocation } from './models/LerpInvocation'; export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; @@ -61,7 +66,6 @@ export type { NoiseOutput } from './models/NoiseOutput'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; -export type { PasteImageInvocation } from './models/PasteImageInvocation'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; diff --git a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts index 1ff7b010c2..e9671a918f 100644 --- a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts @@ -10,6 +10,10 @@ export type AddInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'add'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts new file mode 100644 index 0000000000..474f1d3f3c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Canny edge detection for ControlNet + */ +export type CannyImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
+ */ + id: string; + type?: 'canny_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * low threshold of Canny pixel gradient + */ + low_threshold?: number; + /** + * high threshold of Canny pixel gradient + */ + high_threshold?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts index d250ae4450..f190ab7073 100644 --- a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts @@ -10,6 +10,10 @@ export type CollectInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'collect'; /** * The item to collect (all inputs must be of the same type) diff --git a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts index f03d53a841..1dc390c1be 100644 --- a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts @@ -10,6 +10,10 @@ export type CompelInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'compel'; /** * Prompt diff --git a/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts new file mode 100644 index 0000000000..4a07508be7 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts @@ -0,0 +1,41 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies content shuffle processing to image + */ +export type ContentShuffleImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
+ */ + id: string; + type?: 'content_shuffle_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * content shuffle h parameter + */ + 'h'?: number; + /** + * content shuffle w parameter + */ + 'w'?: number; + /** + * cont + */ + 'f'?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlField.ts b/invokeai/frontend/web/src/services/api/models/ControlField.ts new file mode 100644 index 0000000000..4f493d4410 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlField.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +export type ControlField = { + /** + * processed image + */ + image: ImageField; + /** + * control model used + */ + control_model: string; + /** + * weight given to controlnet + */ + control_weight: number; + /** + * % of total steps at which controlnet is first applied + */ + begin_step_percent: number; + /** + * % of total steps at which controlnet is last applied + */ + end_step_percent: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts new file mode 100644 index 0000000000..e8372f43dd --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Collects ControlNet info to pass to other nodes + */ +export type ControlNetInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
+ */ + id: string; + type?: 'controlnet'; + /** + * image to process + */ + image?: ImageField; + /** + * control model used + */ + control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace'; + /** + * weight given to controlnet + */ + control_weight?: number; + /** + * % of total steps at which controlnet is first applied + */ + begin_step_percent?: number; + /** + * % of total steps at which controlnet is last applied + */ + end_step_percent?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlOutput.ts b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts new file mode 100644 index 0000000000..43f1b3341c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts @@ -0,0 +1,17 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ControlField } from './ControlField'; + +/** + * node output for ControlNet info + */ +export type ControlOutput = { + type?: 'control_output'; + /** + * The control info dict + */ + control?: ControlField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts index 19342acf8f..874df93c30 100644 --- a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts @@ -12,6 +12,10 @@ export type CvInpaintInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'cv_inpaint'; /** * The image to inpaint diff --git a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts index 3cb262e9af..fd5b3475ae 100644 --- a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts @@ -10,6 +10,10 @@ export type DivideInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
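
To show how the new ControlNet types compose, here is a minimal sketch of a two-node graph: a Canny processor feeding a `controlnet` collector node, whose `ControlOutput` carries a `ControlField`. The `ImageField` shape (`{ image_type, image_name }`), the `Edge` shape (`source`/`destination` with `node_id` and `field`), and the assumption that the regenerated `Graph` node union accepts these new node types all come from the rest of this client rather than this diff; ids and file names are hypothetical.

import type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
import type { ControlNetInvocation } from './models/ControlNetInvocation';
import type { Graph } from './models/Graph';

// Edge-detection node: produces the processed control image.
const canny: CannyImageProcessorInvocation = {
  id: 'canny_1',
  type: 'canny_image_processor',
  image: { image_type: 'uploads', image_name: 'example.png' }, // assumed ImageField shape
  low_threshold: 100,
  high_threshold: 200,
};

// Collector node: bundles the processed image with model, weight and step range.
const controlNet: ControlNetInvocation = {
  id: 'controlnet_1',
  type: 'controlnet',
  control_model: 'lllyasviel/sd-controlnet-canny',
  control_weight: 1,
  begin_step_percent: 0,
  end_step_percent: 1,
};

// Wire the processor's output image into the collector's image input.
const graph: Graph = {
  nodes: { [canny.id]: canny, [controlNet.id]: controlNet },
  edges: [
    {
      source: { node_id: canny.id, field: 'image' },
      destination: { node_id: controlNet.id, field: 'image' },
    },
  ],
};
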
+ */ + is_intermediate?: boolean; type?: 'div'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index 039923e585..6be925841b 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -3,31 +3,34 @@ /* eslint-disable */ import type { AddInvocation } from './AddInvocation'; -import type { BlurInvocation } from './BlurInvocation'; import type { CollectInvocation } from './CollectInvocation'; import type { CompelInvocation } from './CompelInvocation'; -import type { CropImageInvocation } from './CropImageInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { DivideInvocation } from './DivideInvocation'; import type { Edge } from './Edge'; import type { GraphInvocation } from './GraphInvocation'; +import type { ImageBlurInvocation } from './ImageBlurInvocation'; +import type { ImageChannelInvocation } from './ImageChannelInvocation'; +import type { ImageConvertInvocation } from './ImageConvertInvocation'; +import type { ImageCropInvocation } from './ImageCropInvocation'; +import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation'; +import type { ImageLerpInvocation } from './ImageLerpInvocation'; +import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation'; +import type { ImagePasteInvocation } from './ImagePasteInvocation'; import type { ImageToImageInvocation } from './ImageToImageInvocation'; import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation'; import type { InfillColorInvocation } from './InfillColorInvocation'; import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation'; import type { InfillTileInvocation } from './InfillTileInvocation'; import type { InpaintInvocation } from './InpaintInvocation'; -import type { InverseLerpInvocation } from './InverseLerpInvocation'; import type { IterateInvocation } from './IterateInvocation'; import type { LatentsToImageInvocation } from './LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation'; -import type { LerpInvocation } from './LerpInvocation'; import type { LoadImageInvocation } from './LoadImageInvocation'; import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation'; import type { MultiplyInvocation } from './MultiplyInvocation'; import type { NoiseInvocation } from './NoiseInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; -import type { PasteImageInvocation } from './PasteImageInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -49,7 +52,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts index 5109a49a68..8512faae74 100644 --- a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts @@ -5,14 +5,17 @@ import type { Graph } from './Graph'; /** - * A node to process inputs and produce outputs. - * May use dependency injection in __init__ to receive providers. 
+ * Execute a graph */ export type GraphInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'graph'; /** * The graph to run diff --git a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts new file mode 100644 index 0000000000..6dea43dc32 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies HED edge detection to image + */ +export type HedImageprocessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + type?: 'hed_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use scribble mode + */ + scribble?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts similarity index 72% rename from invokeai/frontend/web/src/services/api/models/BlurInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts index 0643e4b309..3ba86d8fab 100644 --- a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Blurs an image */ -export type BlurInvocation = { +export type ImageBlurInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'blur'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_blur'; /** * The image to blur */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts index c4edf90fd3..6b04a0b864 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts @@ -5,4 +5,4 @@ /** * The category of an image. Use ImageCategory.OTHER for non-default categories. */ -export type ImageCategory = 'general' | 'control' | 'other'; +export type ImageCategory = 'general' | 'control' | 'mask' | 'other'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts new file mode 100644 index 0000000000..47bfd4110f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Gets a channel from an image. + */ +export type ImageChannelInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_chan'; + /** + * The image to get the channel from + */ + image?: ImageField; + /** + * The channel to get + */ + channel?: 'A' | 'R' | 'G' | 'B'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts new file mode 100644 index 0000000000..4bd59d03b0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Converts an image to a different mode. + */ +export type ImageConvertInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_conv'; + /** + * The image to convert + */ + image?: ImageField; + /** + * The mode to convert to + */ + mode?: 'L' | 'RGB' | 'RGBA' | 'CMYK' | 'YCbCr' | 'LAB' | 'HSV' | 'I' | 'F'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts similarity index 80% rename from invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts index 2676f5cb87..5207ebbf6d 100644 --- a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Crops an image to a specified box. The box can be outside of the image. */ -export type CropImageInvocation = { +export type ImageCropInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'crop'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_crop'; /** * The image to crop */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index c5377b4c76..bc2f19f1b5 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -50,6 +50,10 @@ export type ImageDTO = { * The deleted timestamp of the image. */ deleted_at?: string; + /** + * Whether this is an intermediate image. + */ + is_intermediate: boolean; /** * The session ID that generated this image, if it is a generated image. */ diff --git a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts similarity index 73% rename from invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts index 33c59b7bac..0347d4dc38 100644 --- a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Inverse linear interpolation of all pixels of an image */ -export type InverseLerpInvocation = { +export type ImageInverseLerpInvocation = { /** * The id of this node. Must be unique among all nodes. 
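
Because `ImageDTO` now carries `is_intermediate` as a required field, a client can filter intermediates out of a gallery without a separate query; a one-line sketch:

import type { ImageDTO } from './models/ImageDTO';

// Keep only user-facing images; intermediates are now flagged on the DTO itself.
const visibleImages = (images: ImageDTO[]): ImageDTO[] =>
  images.filter((image) => !image.is_intermediate);
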
*/ id: string; - type?: 'ilerp'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_ilerp'; /** * The image to lerp */ diff --git a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts similarity index 74% rename from invokeai/frontend/web/src/services/api/models/LerpInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts index f2406c2246..388c86061c 100644 --- a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Linear interpolation of all pixels of an image */ -export type LerpInvocation = { +export type ImageLerpInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'lerp'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_lerp'; /** * The image to lerp */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts new file mode 100644 index 0000000000..751ee49158 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Multiplies two images together using `PIL.ImageChops.multiply()`. + */ +export type ImageMultiplyInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_mul'; + /** + * The first image to multiply + */ + image1?: ImageField; + /** + * The second image to multiply + */ + image2?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts similarity index 79% rename from invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts index 8a181ccf07..c883b9a5d8 100644 --- a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Pastes an image into another image. */ -export type PasteImageInvocation = { +export type ImagePasteInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'paste'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_paste'; /** * The base image */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts new file mode 100644 index 0000000000..90639a0569 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts @@ -0,0 +1,21 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Base class for invocations that preprocess images for ControlNet + */ +export type ImageProcessorInvocation = { + /** + * The id of this node. 
Must be unique among all nodes. + */ + id: string; + type?: 'image_processor'; + /** + * image to process + */ + image?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts new file mode 100644 index 0000000000..51f0ee2079 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts @@ -0,0 +1,24 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageCategory } from './ImageCategory'; + +/** + * A set of changes to apply to an image record. + * + * Only limited changes are valid: + * - `image_category`: change the category of an image + * - `session_id`: change the session associated with an image + */ +export type ImageRecordChanges = { + /** + * The image's new category. + */ + image_category?: ImageCategory; + /** + * The image's new session ID. + */ + session_id?: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts index fb43c76921..e63ec93ada 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts @@ -12,6 +12,10 @@ export type ImageToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'img2img'; /** * The prompt to generate an image from @@ -45,6 +49,18 @@ export type ImageToImageInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; /** * The input image */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts index f72d446615..5569c2fa86 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ImageToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'i2l'; /** * The image to encode diff --git a/invokeai/frontend/web/src/services/api/models/ImageType.ts b/invokeai/frontend/web/src/services/api/models/ImageType.ts index bba9134e63..dfc10bf455 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageType.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageType.ts @@ -5,4 +5,4 @@ /** * The type of an image. */ -export type ImageType = 'results' | 'uploads' | 'intermediates'; +export type ImageType = 'results' | 'uploads'; diff --git a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts index 157c976e11..3e637b299c 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts @@ -13,6 +13,10 @@ export type InfillColorInvocation = { * The id of this node. Must be unique among all nodes. 
*/ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'infill_rgba'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts index a4c18ade5d..325bfe2080 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts @@ -12,6 +12,10 @@ export type InfillPatchMatchInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'infill_patchmatch'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts index 12113f57f5..dfb1cbc61d 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts @@ -12,6 +12,10 @@ export type InfillTileInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'infill_tile'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts index 88ead9907c..b8ed268ef9 100644 --- a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts @@ -13,6 +13,10 @@ export type InpaintInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'inpaint'; /** * The prompt to generate an image from @@ -46,6 +50,18 @@ export type InpaintInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; /** * The input image */ diff --git a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts index 0ff7a1258d..15bf92dfea 100644 --- a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts @@ -3,14 +3,17 @@ /* eslint-disable */ /** - * A node to process inputs and produce outputs. - * May use dependency injection in __init__ to receive providers. + * Iterates over a list of items */ export type IterateInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'iterate'; /** * The list of items to iterate over diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts index 8acd872e28..fcaa37d7e8 100644 --- a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts @@ -12,6 +12,10 @@ export type LatentsToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'l2i'; /** * The latents to generate an image from diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts index 29995c6ad9..6436557f64 100644 --- a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts @@ -13,6 +13,10 @@ export type LatentsToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'l2l'; /** * Positive conditioning for generation @@ -42,14 +46,6 @@ export type LatentsToLatentsInvocation = { * The model to use (currently ignored) */ model?: string; - /** - * Whether or not to generate an image that can tile without seams - */ - seamless?: boolean; - /** - * The axes to tile the image on, 'x' and/or 'y' - */ - seamless_axes?: string; /** * The latents to use as a base image */ diff --git a/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts new file mode 100644 index 0000000000..a9bdab56ec --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies line art anime processing to image + */ +export type LineartAnimeImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + type?: 'lineart_anime_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts new file mode 100644 index 0000000000..1aa931525f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies line art processing to image + */ +export type LineartImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
+ */ + id: string; + type?: 'lineart_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use coarse mode + */ + coarse?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts index 745a9b44e4..f20d983f9b 100644 --- a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ImageField } from './ImageField'; /** * Load an image and provide it as output. @@ -12,14 +12,14 @@ export type LoadImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'load_image'; /** - * The type of the image + * The image to load */ - image_type: ImageType; - /** - * The name of the image - */ - image_name: string; + image?: ImageField; }; diff --git a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts index e71b1f464b..e3693f6d98 100644 --- a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts @@ -12,6 +12,10 @@ export type MaskFromAlphaInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'tomask'; /** * The image to create the mask from diff --git a/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..71283b0614 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies Midas depth processing to image + */ +export type MidasDepthImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + type?: 'midas_depth_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * Midas parameter a = amult * PI + */ + a_mult?: number; + /** + * Midas parameter bg_th + */ + bg_th?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts new file mode 100644 index 0000000000..85a2ad15cc --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies MLSD processing to image + */ +export type MlsdImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
Must be unique among all nodes.
+ */
+ id: string;
+ type?: 'mlsd_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * MLSD parameter thr_v
+ */
+ thr_v?: number;
+ /**
+ * MLSD parameter thr_d
+ */
+ thr_d?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
index eede8f18d7..9fd716f33d 100644
--- a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
@@ -10,6 +10,10 @@ export type MultiplyInvocation = {
 * The id of this node. Must be unique among all nodes.
 */
 id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
 type?: 'mul';
 /**
 * The first number
diff --git a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
index 59e50b76f3..239a24bfe5 100644
--- a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
@@ -10,6 +10,10 @@ export type NoiseInvocation = {
 * The id of this node. Must be unique among all nodes.
 */
 id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
 type?: 'noise';
 /**
 * The seed to use
diff --git a/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..519ea7a89d
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies NormalBae processing to image
+ */
+export type NormalbaeImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ type?: 'normalbae_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..44947df15b
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies Openpose processing to image
+ */
+export type OpenposeImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */ + id: string; + type?: 'openpose_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * whether to use hands and face mode + */ + hand_and_face?: boolean; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts index 7047310a87..7a45d0a0ac 100644 --- a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts @@ -10,6 +10,10 @@ export type ParamIntInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'param_int'; /** * The integer value diff --git a/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts new file mode 100644 index 0000000000..59076cb2e1 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies PIDI processing to image + */ +export type PidiImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + type?: 'pidi_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use safe mode + */ + safe?: boolean; + /** + * whether to use scribble mode + */ + scribble?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts index 0a5220c31d..a2f7c2f02a 100644 --- a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts @@ -10,6 +10,10 @@ export type RandomIntInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'rand_int'; /** * The inclusive low value diff --git a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts index c1f80042a6..925511578d 100644 --- a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts @@ -10,6 +10,10 @@ export type RandomRangeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'random_range'; /** * The inclusive low value diff --git a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts index 1c37ca7fe3..3681602a95 100644 --- a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts @@ -10,6 +10,10 @@ export type RangeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'range'; /** * The start of the range diff --git a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts index b918f17130..7dfac68d39 100644 --- a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts @@ -10,6 +10,10 @@ export type RangeOfSizeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'range_of_size'; /** * The start of the range diff --git a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts index c0fabb4984..9a7b6c61e4 100644 --- a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ResizeLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'lresize'; /** * The latents to resize diff --git a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts index e03ed01c81..0bacb5d805 100644 --- a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts @@ -12,6 +12,10 @@ export type RestoreFaceInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'restore_face'; /** * The input image diff --git a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts index f398eaf408..506b21e540 100644 --- a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ScaleLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'lscale'; /** * The latents to scale diff --git a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts index 145895ad75..1b73055584 100644 --- a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts @@ -12,6 +12,10 @@ export type ShowImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'show_image'; /** * The image to show diff --git a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts index 6f2da116a2..23334bd891 100644 --- a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts @@ -10,6 +10,10 @@ export type SubtractInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'sub'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts index 184e35693b..7128ea8440 100644 --- a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts @@ -2,6 +2,8 @@ /* tslint:disable */ /* eslint-disable */ +import type { ImageField } from './ImageField'; + /** * Generates an image using text2img. */ @@ -10,6 +12,10 @@ export type TextToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'txt2img'; /** * The prompt to generate an image from @@ -43,5 +49,17 @@ export type TextToImageInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; }; diff --git a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts index d1ec5ed08c..33eedc0f02 100644 --- a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts @@ -13,6 +13,10 @@ export type TextToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
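
The generation invocations (`txt2img` shown here; `img2img` and `inpaint` gain the same fields) can now reference a control model and a pre-processed control image directly. A minimal sketch, again assuming the `ImageField` shape; the id, prompt, and file name are hypothetical:

import type { TextToImageInvocation } from './models/TextToImageInvocation';

// txt2img node guided by a previously generated Canny control image.
const t2i: TextToImageInvocation = {
  id: 'txt2img_1',
  type: 'txt2img',
  prompt: 'a lighthouse at dusk',
  control_model: 'lllyasviel/sd-controlnet-canny',
  control_image: { image_type: 'results', image_name: 'canny_out.png' }, // assumed ImageField shape
  progress_images: false,
  is_intermediate: false,
};
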
+ */ + is_intermediate?: boolean; type?: 't2l'; /** * Positive conditioning for generation @@ -42,13 +46,5 @@ export type TextToLatentsInvocation = { * The model to use (currently ignored) */ model?: string; - /** - * Whether or not to generate an image that can tile without seams - */ - seamless?: boolean; - /** - * The axes to tile the image on, 'x' and/or 'y' - */ - seamless_axes?: string; }; diff --git a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts index 8416c2454d..d0aca63964 100644 --- a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts @@ -12,6 +12,10 @@ export type UpscaleInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'upscale'; /** * The input image diff --git a/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts new file mode 100644 index 0000000000..e2f1bc2111 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $CannyImageProcessorInvocation = { + description: `Canny edge detection for ControlNet`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + low_threshold: { + type: 'number', + description: `low threshold of Canny pixel gradient`, + }, + high_threshold: { + type: 'number', + description: `high threshold of Canny pixel gradient`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts new file mode 100644 index 0000000000..9c51fdecc0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts @@ -0,0 +1,43 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ContentShuffleImageProcessorInvocation = { + description: `Applies content shuffle processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ 'h': {
+ type: 'number',
+ description: `content shuffle h parameter`,
+ },
+ 'w': {
+ type: 'number',
+ description: `content shuffle w parameter`,
+ },
+ 'f': {
+ type: 'number',
+ description: `content shuffle f parameter`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts
new file mode 100644
index 0000000000..81292b8638
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ControlField = {
+ properties: {
+ image: {
+ type: 'all-of',
+ description: `processed image`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ isRequired: true,
+ },
+ control_model: {
+ type: 'string',
+ description: `control model used`,
+ isRequired: true,
+ },
+ control_weight: {
+ type: 'number',
+ description: `weight given to controlnet`,
+ isRequired: true,
+ },
+ begin_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is first applied`,
+ isRequired: true,
+ maximum: 1,
+ },
+ end_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is last applied`,
+ isRequired: true,
+ maximum: 1,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts
new file mode 100644
index 0000000000..29ff507e66
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ControlNetInvocation = {
+ description: `Collects ControlNet info to pass to other nodes`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node.
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + control_model: { + type: 'Enum', + }, + control_weight: { + type: 'number', + description: `weight given to controlnet`, + maximum: 1, + }, + begin_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is first applied`, + maximum: 1, + }, + end_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is last applied`, + maximum: 1, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts new file mode 100644 index 0000000000..d94d633fca --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts @@ -0,0 +1,28 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ControlOutput = { + description: `node output for ControlNet info`, + properties: { + type: { + type: 'Enum', + }, + control: { + type: 'all-of', + description: `The control info dict`, + contains: [{ + type: 'ControlField', + }], + }, + width: { + type: 'number', + description: `The width of the noise in pixels`, + isRequired: true, + }, + height: { + type: 'number', + description: `The height of the noise in pixels`, + isRequired: true, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts new file mode 100644 index 0000000000..3cffa008f5 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $HedImageprocessorInvocation = { + description: `Applies HED edge detection to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + scribble: { + type: 'boolean', + description: `whether to use scribble mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts new file mode 100644 index 0000000000..36748982c5 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts @@ -0,0 +1,23 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ImageProcessorInvocation = { + description: `Base class for invocations that preprocess images for ControlNet`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts new file mode 100644 index 0000000000..63a9c8158c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $LineartAnimeImageProcessorInvocation = { + description: `Applies line art anime processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts new file mode 100644 index 0000000000..6ba4064823 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $LineartImageProcessorInvocation = { + description: `Applies line art processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + coarse: { + type: 'boolean', + description: `whether to use coarse mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..ea0b2b0099 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $MidasDepthImageProcessorInvocation = { + description: `Applies Midas depth processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + a_mult: { + type: 'number', + description: `Midas parameter a = amult * PI`, + }, + bg_th: { + type: 'number', + description: `Midas parameter bg_th`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts new file mode 100644 index 0000000000..1bff7579cc --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts @@ -0,0 +1,39 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $MlsdImageProcessorInvocation = { + description: `Applies MLSD processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + thr_v: { + type: 'number', + description: `MLSD parameter thr_v`, + }, + thr_d: { + type: 'number', + description: `MLSD parameter thr_d`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts new file mode 100644 index 0000000000..7cdfe6f3ae --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $NormalbaeImageProcessorInvocation = { + description: `Applies NormalBae processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts new file mode 100644 index 0000000000..2a187e9cf2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $OpenposeImageProcessorInvocation = { + description: `Applies Openpose processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + hand_and_face: { + type: 'boolean', + description: `whether to use hands and face mode`, + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts new file mode 100644 index 0000000000..0fd53967c2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts @@ -0,0 +1,39 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $PidiImageProcessorInvocation = { + description: `Applies PIDI processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + safe: { + type: 'boolean', + description: `whether to use safe mode`, + }, + scribble: { + type: 'boolean', + description: `whether to use scribble mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts new file mode 100644 index 0000000000..e5b0387d5a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts @@ -0,0 +1,16 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $RandomIntInvocation = { + description: `Outputs a single random integer.`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index 13b2ef836a..d01a97a45e 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -4,6 +4,7 @@ import type { Body_upload_image } from '../models/Body_upload_image'; import type { ImageCategory } from '../models/ImageCategory'; import type { ImageDTO } from '../models/ImageDTO'; +import type { ImageRecordChanges } from '../models/ImageRecordChanges'; import type { ImageType } from '../models/ImageType'; import type { ImageUrlsDTO } from '../models/ImageUrlsDTO'; import type { PaginatedResults_ImageDTO_ } from '../models/PaginatedResults_ImageDTO_'; @@ -65,20 +66,32 @@ export class ImagesService { * @throws ApiError */ public static uploadImage({ - imageType, formData, imageCategory, + isIntermediate = false, + sessionId, }: { - imageType: ImageType, formData: Body_upload_image, + /** + * The category of the image + */ imageCategory?: ImageCategory, + /** + * Whether this is an intermediate image + */ + isIntermediate?: boolean, + /** + * The session ID associated with this upload, if any + */ + sessionId?: string, }): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v1/images/', query: { - 'image_type': imageType, 'image_category': imageCategory, + 'is_intermediate': isIntermediate, + 'session_id': sessionId, }, formData: formData, mediaType: 'multipart/form-data', @@ -132,6 +145,9 @@ export class ImagesService { imageType, imageName, }: { + /** + * The type of image to delete + */ imageType: ImageType, /** * The name of the image to delete @@ -151,6 +167,42 @@ export class ImagesService { }); } + /** + * Update Image + * Updates an image + * @returns ImageDTO Successful Response + * @throws ApiError + */ + public static updateImage({ + imageType, + imageName, + requestBody, + }: { + /** + * The type of image to update + */ + imageType: ImageType, + /** + * The name of the image to update + */ + imageName: string, + requestBody: ImageRecordChanges, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v1/images/{image_type}/{image_name}', + path: { + 'image_type': imageType, + 'image_name': imageName, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + /** * Get Image Metadata * Gets an image's metadata diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 23597c9e9e..1c55d36502 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -2,34 +2,37 @@ /* tslint:disable */ /* eslint-disable */ import type { AddInvocation } from '../models/AddInvocation'; -import type { BlurInvocation } from '../models/BlurInvocation'; import type { CollectInvocation } from '../models/CollectInvocation'; import type { CompelInvocation } from '../models/CompelInvocation'; -import type { CropImageInvocation } from '../models/CropImageInvocation'; import type { CvInpaintInvocation } from '../models/CvInpaintInvocation'; import type { DivideInvocation } from '../models/DivideInvocation'; import type { Edge } from '../models/Edge'; import type { Graph } 
from '../models/Graph'; import type { GraphExecutionState } from '../models/GraphExecutionState'; import type { GraphInvocation } from '../models/GraphInvocation'; +import type { ImageBlurInvocation } from '../models/ImageBlurInvocation'; +import type { ImageChannelInvocation } from '../models/ImageChannelInvocation'; +import type { ImageConvertInvocation } from '../models/ImageConvertInvocation'; +import type { ImageCropInvocation } from '../models/ImageCropInvocation'; +import type { ImageInverseLerpInvocation } from '../models/ImageInverseLerpInvocation'; +import type { ImageLerpInvocation } from '../models/ImageLerpInvocation'; +import type { ImageMultiplyInvocation } from '../models/ImageMultiplyInvocation'; +import type { ImagePasteInvocation } from '../models/ImagePasteInvocation'; import type { ImageToImageInvocation } from '../models/ImageToImageInvocation'; import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation'; import type { InfillColorInvocation } from '../models/InfillColorInvocation'; import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation'; import type { InfillTileInvocation } from '../models/InfillTileInvocation'; import type { InpaintInvocation } from '../models/InpaintInvocation'; -import type { InverseLerpInvocation } from '../models/InverseLerpInvocation'; import type { IterateInvocation } from '../models/IterateInvocation'; import type { LatentsToImageInvocation } from '../models/LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvocation'; -import type { LerpInvocation } from '../models/LerpInvocation'; import type { LoadImageInvocation } from '../models/LoadImageInvocation'; import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation'; import type { MultiplyInvocation } from '../models/MultiplyInvocation'; import type { NoiseInvocation } from '../models/NoiseInvocation'; import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedResults_GraphExecutionState_'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; -import type { PasteImageInvocation } from '../models/PasteImageInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -151,7 +154,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | 
ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -188,7 +191,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index bd1d60099a..f1eb844f2c 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -8,7 +8,7 @@ import { import { socketSubscribed, socketUnsubscribed } from './actions'; import { AppThunkDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionInvoked, sessionCreated } from 'services/thunks/session'; +import { sessionCreated } from 'services/thunks/session'; import { OpenAPI } from 'services/api'; import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 
'app/logging/useLogger'; @@ -64,15 +64,9 @@ export const socketMiddleware = () => { if (sessionCreated.fulfilled.match(action)) { const sessionId = action.payload.id; - const sessionLog = socketioLog.child({ sessionId }); const oldSessionId = getState().system.sessionId; if (oldSessionId) { - sessionLog.debug( - { oldSessionId }, - `Unsubscribed from old session (${oldSessionId})` - ); - socket.emit('unsubscribe', { session: oldSessionId, }); @@ -85,8 +79,6 @@ export const socketMiddleware = () => { ); } - sessionLog.debug(`Subscribe to new session (${sessionId})`); - socket.emit('subscribe', { session: sessionId }); dispatch( @@ -95,9 +87,6 @@ export const socketMiddleware = () => { timestamp: getTimestamp(), }) ); - - // Finally we actually invoke the session, starting processing - dispatch(sessionInvoked({ sessionId })); } next(action); diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 4431a9fd8b..5262b26d1e 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -1,7 +1,6 @@ import { MiddlewareAPI } from '@reduxjs/toolkit'; import { AppDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCanceled } from 'services/thunks/session'; import { Socket } from 'socket.io-client'; import { generatorProgress, @@ -16,12 +15,6 @@ import { import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { Logger } from 'roarr'; import { JsonObject } from 'roarr/dist/types'; -import { - receivedResultImagesPage, - receivedUploadImagesPage, -} from 'services/thunks/gallery'; -import { receivedModels } from 'services/thunks/model'; -import { receivedOpenAPISchema } from 'services/thunks/schema'; import { makeToast } from '../../../app/components/Toaster'; import { addToast } from '../../../features/system/store/systemSlice'; @@ -43,37 +36,13 @@ export const setEventListeners = (arg: SetEventListenersArg) => { dispatch(socketConnected({ timestamp: getTimestamp() })); - const { results, uploads, models, nodes, config, system } = getState(); + const { sessionId } = getState().system; - const { disabledTabs } = config; - - // These thunks need to be dispatch in middleware; cannot handle in a reducer - if (!results.ids.length) { - dispatch(receivedResultImagesPage()); - } - - if (!uploads.ids.length) { - dispatch(receivedUploadImagesPage()); - } - - if (!models.ids.length) { - dispatch(receivedModels()); - } - - if (!nodes.schema && !disabledTabs.includes('nodes')) { - dispatch(receivedOpenAPISchema()); - } - - if (system.sessionId) { - log.debug( - { sessionId: system.sessionId }, - `Subscribed to existing session (${system.sessionId})` - ); - - socket.emit('subscribe', { session: system.sessionId }); + if (sessionId) { + socket.emit('subscribe', { session: sessionId }); dispatch( socketSubscribed({ - sessionId: system.sessionId, + sessionId, timestamp: getTimestamp(), }) ); @@ -101,7 +70,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Disconnect */ socket.on('disconnect', () => { - log.debug('Disconnected'); dispatch(socketDisconnected({ timestamp: getTimestamp() })); }); @@ -109,18 +77,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation started */ socket.on('invocation_started', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) 
{ - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored invocation started (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation started (${data.node.type})` - ); dispatch(invocationStarted({ data, timestamp: getTimestamp() })); }); @@ -128,18 +84,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Generator progress */ socket.on('generator_progress', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) { - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored generator progress (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Generator progress (${data.node.type})` - ); dispatch(generatorProgress({ data, timestamp: getTimestamp() })); }); @@ -147,10 +91,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation error */ socket.on('invocation_error', (data) => { - log.error( - { data, sessionId: data.graph_execution_state_id }, - `Invocation error (${data.node.type})` - ); dispatch(invocationError({ data, timestamp: getTimestamp() })); }); @@ -158,19 +98,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation complete */ socket.on('invocation_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation complete (${data.node.type})` - ); - const sessionId = data.graph_execution_state_id; - - const { cancelType, isCancelScheduled } = getState().system; - - // Handle scheduled cancelation - if (cancelType === 'scheduled' && isCancelScheduled) { - dispatch(sessionCanceled({ sessionId })); - } - dispatch( invocationComplete({ data, @@ -183,10 +110,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Graph complete */ socket.on('graph_execution_state_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Graph execution state complete (${data.graph_execution_state_id})` - ); dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() })); }); }; diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts index 01e8a986b2..11960e00d2 100644 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ b/invokeai/frontend/web/src/services/thunks/gallery.ts @@ -1,45 +1,64 @@ -import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { ImagesService } from 'services/api'; +import { ImagesService, PaginatedResults_ImageDTO_ } from 'services/api'; export const IMAGES_PER_PAGE = 20; -const galleryLog = log.child({ namespace: 'gallery' }); +type ReceivedResultImagesPageThunkConfig = { + rejectValue: { + error: unknown; + }; +}; -export const receivedResultImagesPage = createAppAsyncThunk( +export const receivedResultImagesPage = createAppAsyncThunk< + PaginatedResults_ImageDTO_, + void, + ReceivedResultImagesPageThunkConfig +>( 'results/receivedResultImagesPage', async (_arg, { getState, rejectWithValue }) => { - const { page, pages, nextPage } = getState().results; + const { page, pages, nextPage, upsertedImageCount } = getState().results; - if (nextPage === page) { - rejectWithValue([]); - } + // If many images have been upserted, we need to offset the page number + // TODO: add an offset 
param to the list images endpoint + const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const response = await ImagesService.listImagesWithMetadata({ imageType: 'results', imageCategory: 'general', - page: getState().results.nextPage, + page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); - galleryLog.info({ response }, `Received ${response.items.length} results`); - return response; } ); -export const receivedUploadImagesPage = createAppAsyncThunk( +type ReceivedUploadImagesPageThunkConfig = { + rejectValue: { + error: unknown; + }; +}; + +export const receivedUploadImagesPage = createAppAsyncThunk< + PaginatedResults_ImageDTO_, + void, + ReceivedUploadImagesPageThunkConfig +>( 'uploads/receivedUploadImagesPage', - async (_arg, { getState }) => { + async (_arg, { getState, rejectWithValue }) => { + const { page, pages, nextPage, upsertedImageCount } = getState().uploads; + + // If many images have been upserted, we need to offset the page number + // TODO: add an offset param to the list images endpoint + const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); + const response = await ImagesService.listImagesWithMetadata({ imageType: 'uploads', imageCategory: 'general', - page: getState().uploads.nextPage, + page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); - galleryLog.info({ response }, `Received ${response.items.length} uploads`); - return response; } ); diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index 6831eb647d..f0c0456202 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,10 +1,6 @@ -import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { ImagesService } from 'services/api'; -import { getHeaders } from 'services/util/getHeaders'; - -const imagesLog = log.child({ namespace: 'image' }); type imageUrlsReceivedArg = Parameters< (typeof ImagesService)['getImageUrls'] @@ -17,7 +13,6 @@ export const imageUrlsReceived = createAppAsyncThunk( 'api/imageUrlsReceived', async (arg: imageUrlsReceivedArg) => { const response = await ImagesService.getImageUrls(arg); - imagesLog.info({ arg, response }, 'Received image urls'); return response; } ); @@ -33,7 +28,6 @@ export const imageMetadataReceived = createAppAsyncThunk( 'api/imageMetadataReceived', async (arg: imageMetadataReceivedArg) => { const response = await ImagesService.getImageMetadata(arg); - imagesLog.info({ arg, response }, 'Received image record'); return response; } ); @@ -53,11 +47,7 @@ export const imageUploaded = createAppAsyncThunk( // strip out `activeTabName` from arg - the route does not need it const { activeTabName, ...rest } = arg; const response = await ImagesService.uploadImage(rest); - const { location } = getHeaders(response); - - imagesLog.debug({ arg: '', response, location }, 'Image uploaded'); - - return { response, location }; + return response; } ); @@ -70,9 +60,19 @@ export const imageDeleted = createAppAsyncThunk( 'api/imageDeleted', async (arg: ImageDeletedArg) => { const response = await ImagesService.deleteImage(arg); - - imagesLog.debug({ arg, response }, 'Image deleted'); - + return response; + } +); + +type ImageUpdatedArg = Parameters<(typeof ImagesService)['updateImage']>[0]; + +/** + * `ImagesService.updateImage()` thunk + */ +export const imageUpdated = createAppAsyncThunk( + 'api/imageUpdated', + 
async (arg: ImageUpdatedArg) => { + const response = await ImagesService.updateImage(arg); return response; } ); diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts index dca4134886..cf87fb30f5 100644 --- a/invokeai/frontend/web/src/services/thunks/session.ts +++ b/invokeai/frontend/web/src/services/thunks/session.ts @@ -1,7 +1,7 @@ import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SessionsService } from 'services/api'; +import { GraphExecutionState, SessionsService } from 'services/api'; import { log } from 'app/logging/useLogger'; -import { serializeError } from 'serialize-error'; +import { isObject } from 'lodash-es'; const sessionLog = log.child({ namespace: 'session' }); @@ -11,99 +11,89 @@ type SessionCreatedArg = { >[0]['requestBody']; }; +type SessionCreatedThunkConfig = { + rejectValue: { arg: SessionCreatedArg; error: unknown }; +}; + /** * `SessionsService.createSession()` thunk */ -export const sessionCreated = createAppAsyncThunk( - 'api/sessionCreated', - async (arg: SessionCreatedArg, { rejectWithValue }) => { - try { - const response = await SessionsService.createSession({ - requestBody: arg.graph, - }); - sessionLog.info({ arg, response }, `Session created (${response.id})`); - return response; - } catch (err: any) { - sessionLog.error( - { - error: serializeError(err), - }, - 'Problem creating session' - ); - return rejectWithValue(err.message); - } - } -); - -type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0]; - -/** - * `SessionsService.addNode()` thunk - */ -export const nodeAdded = createAppAsyncThunk( - 'api/nodeAdded', - async ( - arg: { node: NodeAddedArg['requestBody']; sessionId: string }, - _thunkApi - ) => { - const response = await SessionsService.addNode({ - requestBody: arg.node, - sessionId: arg.sessionId, +export const sessionCreated = createAppAsyncThunk< + GraphExecutionState, + SessionCreatedArg, + SessionCreatedThunkConfig +>('api/sessionCreated', async (arg, { rejectWithValue }) => { + try { + const response = await SessionsService.createSession({ + requestBody: arg.graph, }); - - sessionLog.info({ arg, response }, `Node added (${response})`); - return response; + } catch (error) { + return rejectWithValue({ arg, error }); } -); +}); + +type SessionInvokedArg = { sessionId: string }; + +type SessionInvokedThunkConfig = { + rejectValue: { + arg: SessionInvokedArg; + error: unknown; + }; +}; + +const isErrorWithStatus = (error: unknown): error is { status: number } => + isObject(error) && 'status' in error; /** * `SessionsService.invokeSession()` thunk */ -export const sessionInvoked = createAppAsyncThunk( - 'api/sessionInvoked', - async (arg: { sessionId: string }, { rejectWithValue }) => { - const { sessionId } = arg; +export const sessionInvoked = createAppAsyncThunk< + void, + SessionInvokedArg, + SessionInvokedThunkConfig +>('api/sessionInvoked', async (arg, { rejectWithValue }) => { + const { sessionId } = arg; - try { - const response = await SessionsService.invokeSession({ - sessionId, - all: true, - }); - sessionLog.info({ arg, response }, `Session invoked (${sessionId})`); - - return response; - } catch (error) { - const err = error as any; - if (err.status === 403) { - return rejectWithValue(err.body.detail); - } - throw error; + try { + const response = await SessionsService.invokeSession({ + sessionId, + all: true, + }); + return response; + } catch (error) { + if (isErrorWithStatus(error) && error.status === 403) { + return 
rejectWithValue({ arg, error: (error as any).body.detail }); } + return rejectWithValue({ arg, error }); } -); +}); type SessionCanceledArg = Parameters< (typeof SessionsService)['cancelSessionInvoke'] >[0]; - +type SessionCanceledThunkConfig = { + rejectValue: { + arg: SessionCanceledArg; + error: unknown; + }; +}; /** * `SessionsService.cancelSession()` thunk */ -export const sessionCanceled = createAppAsyncThunk( - 'api/sessionCanceled', - async (arg: SessionCanceledArg, _thunkApi) => { - const { sessionId } = arg; +export const sessionCanceled = createAppAsyncThunk< + void, + SessionCanceledArg, + SessionCanceledThunkConfig +>('api/sessionCanceled', async (arg: SessionCanceledArg, _thunkApi) => { + const { sessionId } = arg; - const response = await SessionsService.cancelSessionInvoke({ - sessionId, - }); + const response = await SessionsService.cancelSessionInvoke({ + sessionId, + }); - sessionLog.info({ arg, response }, `Session canceled (${sessionId})`); - - return response; - } -); + return response; +}); type SessionsListedArg = Parameters< (typeof SessionsService)['listSessions'] diff --git a/pyproject.toml b/pyproject.toml index 7913905aae..38aa71bd0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ dependencies = [ "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel~=1.1.5", + "controlnet-aux>=0.0.4", + "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets", "diffusers[torch]~=0.16.1", "dnspython==2.2.1", @@ -54,6 +56,7 @@ dependencies = [ "flaskwebgui==1.0.3", "gfpgan==1.3.8", "huggingface-hub>=0.11.1", + "mediapipe", # needed for "mediapipeface" controlnet model "npyscreen", "numpy<1.24", "omegaconf", diff --git a/scripts/controlnet_legacy_txt2img_example.py b/scripts/controlnet_legacy_txt2img_example.py new file mode 100644 index 0000000000..eb299c9d47 --- /dev/null +++ b/scripts/controlnet_legacy_txt2img_example.py @@ -0,0 +1,54 @@ +import os +import torch +import cv2 +import numpy as np +from PIL import Image + +from diffusers.utils import load_image +from diffusers.models.controlnet import ControlNetModel +from invokeai.backend.generator import Txt2Img +from invokeai.backend.model_management import ModelManager + + +print("loading 'Girl with a Pearl Earring' image") +image = load_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" +) +image.show() + +print("preprocessing image with Canny edge detection") +image_np = np.array(image) +low_threshold = 100 +high_threshold = 200 +canny_np = cv2.Canny(image_np, low_threshold, high_threshold) +canny_image = Image.fromarray(canny_np) +canny_image.show() + +# using invokeai model management for base model +print("loading base model stable-diffusion-1.5") +model_config_path = os.getcwd() + "/../configs/models.yaml" +model_manager = ModelManager(model_config_path) +model = model_manager.get_model('stable-diffusion-1.5') + +print("loading control model lllyasviel/sd-controlnet-canny") +canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", + torch_dtype=torch.float16).to("cuda") + +print("testing Txt2Img() constructor with control_model arg") +txt2img_canny = Txt2Img(model, control_model=canny_controlnet) + +print("testing Txt2Img.generate() with control_image arg") +outputs = txt2img_canny.generate( + prompt="old man", + control_image=canny_image, + 
control_weight=1.0, + seed=0, + num_steps=30, + precision="float16", +) +generate_output = next(outputs) +out_image = generate_output.image +out_image.show() diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index d4631ec735..9f433aa330 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -35,6 +35,7 @@ def mock_services(): graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), restoration = None, # type: ignore + configuration = None, # type: ignore ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 80ed427485..6e1dde716c 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -33,6 +33,7 @@ def mock_services() -> InvocationServices: graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), restoration = None, # type: ignore + configuration = None, # type: ignore ) @pytest.fixture() diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index e334953d7e..d16d67d815 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -49,7 +49,7 @@ class ImageTestInvocation(BaseInvocation): prompt: str = Field(default = "") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: - return ImageTestInvocationOutput(image=ImageField(image_name=self.id, width=512, height=512, mode="", info={})) + return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) class PromptCollectionTestInvocationOutput(BaseInvocationOutput): type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
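
A usage note on the reworked upload route: the client no longer supplies image_type (the server now pins uploads to ImageType.UPLOAD), and image_category, is_intermediate, and session_id travel as optional query parameters instead. A minimal sketch of the generated client call under those assumptions; the helper name uploadExample and the File source are illustrative, not part of this PR:

    import { ImagesService } from 'services/api';

    // Sketch only: upload a file using the new query parameters.
    // `file` is a placeholder, e.g. obtained from an <input type="file"> element.
    const uploadExample = async (file: File) => {
      const imageDTO = await ImagesService.uploadImage({
        formData: { file },
        imageCategory: 'general', // optional; server default is 'general'
        isIntermediate: false,    // optional; server default is false
        sessionId: undefined,     // optional; links the upload to a session
      });
      return imageDTO; // resolves to the created ImageDTO
    };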
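
The new PATCH endpoint surfaces on the client as ImagesService.updateImage, wrapped by the imageUpdated thunk. A sketch of a direct call; the shape of ImageRecordChanges is defined elsewhere in this PR, so the session_id field below is an assumption about that schema, not a confirmed part of it:

    import { ImagesService } from 'services/api';

    // Sketch only: apply a partial update to an existing image record.
    const updateExample = async () => {
      const updated = await ImagesService.updateImage({
        imageType: 'results',               // ImageType of the target image
        imageName: 'example.png',           // placeholder image name
        requestBody: { session_id: 'abc' }, // assumed ImageRecordChanges field
      });
      return updated; // resolves to the updated ImageDTO
    };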
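
A behavior change worth calling out: the socket middleware no longer dispatches sessionInvoked as a side effect of sessionCreated.fulfilled, so creating a session and starting processing are now separate, explicit steps. Roughly, assuming a prebuilt graph and the app's typed dispatch:

    import { AppThunkDispatch } from 'app/store/store';
    import { Graph } from 'services/api';
    import { sessionCreated, sessionInvoked } from 'services/thunks/session';

    // Sketch only: create a session, then explicitly start processing.
    const createAndInvoke = async (dispatch: AppThunkDispatch, graph: Graph) => {
      // sessionCreated resolves to the new GraphExecutionState
      const session = await dispatch(sessionCreated({ graph })).unwrap();
      // invocation is no longer triggered by the middleware
      await dispatch(sessionInvoked({ sessionId: session.id })).unwrap();
    };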
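
On the gallery pagination tweak: images upserted into the results and uploads slices outside of pagination shift what the server considers page N, so the thunks now add Math.floor(upsertedImageCount / IMAGES_PER_PAGE) to nextPage. With IMAGES_PER_PAGE = 20, for example, 45 upserted images give an offset of floor(45 / 20) = 2, so a request for nextPage 3 actually fetches page 5. The in-code TODO flags this as a stopgap until the list endpoint accepts a true offset parameter.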