diff --git a/docs/nodes/communityNodes.md b/docs/nodes/communityNodes.md index 3879cdc3c3..f3b8af0425 100644 --- a/docs/nodes/communityNodes.md +++ b/docs/nodes/communityNodes.md @@ -8,7 +8,7 @@ To use a node, add the node to the `nodes` folder found in your InvokeAI install The suggested method is to use `git clone` to clone the repository the node is found in. This allows for easy updates of the node in the future. -If you'd prefer, you can also just download the `.py` file from the linked repository and add it to the `nodes` folder. +If you'd prefer, you can also just download the whole node folder from the linked repository and add it to the `nodes` folder. To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor. @@ -26,6 +26,7 @@ To use a community workflow, download the the `.json` node graph file and load i + [Image Picker](#image-picker) + [Load Video Frame](#load-video-frame) + [Make 3D](#make-3d) + + [Match Histogram](#match-histogram) + [Oobabooga](#oobabooga) + [Prompt Tools](#prompt-tools) + [Remote Image](#remote-image) @@ -208,6 +209,23 @@ This includes 15 Nodes: +-------------------------------- +### Match Histogram + +**Description:** An InvokeAI node to match a histogram from one image to another. It is similar to the `color correct` node in the main InvokeAI, but it works in the YCbCr colourspace, can handle images of different sizes, and does not require a mask input. +- Option to transfer only the luminance channel. +- Option to save the output as grayscale. + +A good use case for this node is to normalize the colors of an image that has been through the tiled scaling workflow of my XYGrid Nodes. + +See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md + +**Node Link:** https://github.com/skunkworxdark/match_histogram + +**Output Examples** + + + -------------------------------- ### Oobabooga @@ -237,22 +255,30 @@ This node works best with SDXL models, especially as the style can be described -------------------------------- ### Prompt Tools -**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These were written to accompany the PromptsFromFile node and other prompt generation nodes. +**Description:** A set of InvokeAI nodes that add general prompt (string) manipulation tools. Designed to accompany the `Prompts From File` node and other prompt generation nodes. + +1. `Prompt To File` - Saves a prompt or collection of prompts to a file, one per line. There is an append/overwrite option. +2. `PTFields Collect` - Converts image generation fields into a JSON string that can be passed to `Prompt To File`. +3. `PTFields Expand` - Takes a JSON string and converts it to individual generation parameters. This can be fed from the Prompts to file node. +4. `Prompt Strength` - Formats a prompt with a strength value, like the weighted format of compel. +5. `Prompt Strength Combine` - Combines weighted prompts for use with `.and()`/`.blend()`. +6. `CSV To Index String` - Gets a string from a CSV by index. Includes a random index option. + +The following nodes are now included in v3.2 of Invoke and are no longer part of this set of tools.
+- `Prompt Join` -> `String Join` +- `Prompt Join Three` -> `String Join Three` +- `Prompt Replace` -> `String Replace` +- `Prompt Split Neg` -> `String Split Neg` -1. PromptJoin - Joins to prompts into one. -2. PromptReplace - performs a search and replace on a prompt. With the option of using regex. -3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative. -4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option. -5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file. -6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node. -7. PromptJoinThree - Joins 3 prompt together. -8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel. -9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node. See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md **Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes +**Workflow Examples** + + + -------------------------------- ### Remote Image @@ -339,15 +365,27 @@ Highlights/Midtones/Shadows (with LUT blur enabled): -------------------------------- ### XY Image to Grid and Images to Grids nodes -**Description:** Image to grid nodes and supporting tools. +**Description:** These nodes add the following to InvokeAI: +- Generate grids of images from multiple input images +- Create XY grid images with labels from parameters +- Split images into overlapping tiles for processing (for super-resolution workflows) +- Recombine image tiles into a single output image, blending the seams -1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then multiple grids will be created until it runs out of images. -2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporting nodes. See example node setups for more details. +The nodes include: +1. `Images To Grids` - Combines multiple images into a grid of images. +2. `XYImage To Grid` - Takes X & Y params and creates a labeled image grid. +3. `XYImage Tiles` - Super-resolution (embiggen) style tiled resizing. +4. `Image To XYImages` - Takes an image and cuts it up into a number of columns and rows. +5. Multiple supporting nodes - Helper nodes for data wrangling and building `XYImage` collections. See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md **Node Link:** https://github.com/skunkworxdark/XYGrid_nodes +**Output Examples** + + + -------------------------------- ### Example Node Template diff --git a/docs/nodes/defaultNodes.md b/docs/nodes/defaultNodes.md index ace51163ef..1f490dfe81 100644 --- a/docs/nodes/defaultNodes.md +++ b/docs/nodes/defaultNodes.md @@ -1,104 +1,106 @@ # List of Default Nodes -The table below contains a list of the default nodes shipped with InvokeAI and their descriptions. +The table below contains a list of the default nodes shipped with InvokeAI and +their descriptions.
-| Node | Function | -|: ---------------------------------- | :--------------------------------------------------------------------------------------| -|Add Integers | Adds two numbers| -|Boolean Primitive Collection | A collection of boolean primitive values| -|Boolean Primitive | A boolean primitive value| -|Canny Processor | Canny edge detection for ControlNet| -|CLIP Skip | Skip layers in clip text_encoder model.| -|Collect | Collects values into a collection| -|Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image.| -|Color Primitive | A color primitive value| -|Compel Prompt | Parse prompt using compel package to conditioning.| -|Conditioning Primitive Collection | A collection of conditioning tensor primitive values| -|Conditioning Primitive | A conditioning tensor primitive value| -|Content Shuffle Processor | Applies content shuffle processing to image| -|ControlNet | Collects ControlNet info to pass to other nodes| -|Denoise Latents | Denoises noisy latents to decodable images| -|Divide Integers | Divides two numbers| -|Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator| -|[FaceMask](./detailedNodes/faceTools.md#facemask) | Generates masks for faces in an image to use with Inpainting| -|[FaceIdentifier](./detailedNodes/faceTools.md#faceidentifier) | Identifies and labels faces in an image| -|[FaceOff](./detailedNodes/faceTools.md#faceoff) | Creates a new image that is a scaled bounding box with a mask on the face for Inpainting| -|Float Math | Perform basic math operations on two floats| -|Float Primitive Collection | A collection of float primitive values| -|Float Primitive | A float primitive value| -|Float Range | Creates a range| -|HED (softedge) Processor | Applies HED edge detection to image| -|Blur Image | Blurs an image| -|Extract Image Channel | Gets a channel from an image.| -|Image Primitive Collection | A collection of image primitive values| -|Integer Math | Perform basic math operations on two integers| -|Convert Image Mode | Converts an image to a different mode.| -|Crop Image | Crops an image to a specified box. The box can be outside of the image.| -|Image Hue Adjustment | Adjusts the Hue of an image.| -|Inverse Lerp Image | Inverse linear interpolation of all pixels of an image| -|Image Primitive | An image primitive value| -|Lerp Image | Linear interpolation of all pixels of an image| -|Offset Image Channel | Add to or subtract from an image color channel by a uniform value.| -|Multiply Image Channel | Multiply or Invert an image color channel by a scalar value.| -|Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`.| -|Blur NSFW Image | Add blur to NSFW-flagged images| -|Paste Image | Pastes an image into another image.| -|ImageProcessor | Base class for invocations that preprocess images for ControlNet| -|Resize Image | Resizes an image to specific dimensions| -|Round Float | Rounds a float to a specified number of decimal places| -|Float to Integer | Converts a float to an integer. 
Optionally rounds to an even multiple of a input number.| -|Scale Image | Scales an image by a factor| -|Image to Latents | Encodes an image into latents.| -|Add Invisible Watermark | Add an invisible watermark to an image| -|Solid Color Infill | Infills transparent areas of an image with a solid color| -|PatchMatch Infill | Infills transparent areas of an image using the PatchMatch algorithm| -|Tile Infill | Infills transparent areas of an image with tiles of the image| -|Integer Primitive Collection | A collection of integer primitive values| -|Integer Primitive | An integer primitive value| -|Iterate | Iterates over a list of items| -|Latents Primitive Collection | A collection of latents tensor primitive values| -|Latents Primitive | A latents tensor primitive value| -|Latents to Image | Generates an image from latents.| -|Leres (Depth) Processor | Applies leres processing to image| -|Lineart Anime Processor | Applies line art anime processing to image| -|Lineart Processor | Applies line art processing to image| -|LoRA Loader | Apply selected lora to unet and text_encoder.| -|Main Model Loader | Loads a main model, outputting its submodels.| -|Combine Mask | Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.| -|Mask Edge | Applies an edge mask to an image| -|Mask from Alpha | Extracts the alpha channel of an image as a mask.| -|Mediapipe Face Processor | Applies mediapipe face processing to image| -|Midas (Depth) Processor | Applies Midas depth processing to image| -|MLSD Processor | Applies MLSD processing to image| -|Multiply Integers | Multiplies two numbers| -|Noise | Generates latent noise.| -|Normal BAE Processor | Applies NormalBae processing to image| -|ONNX Latents to Image | Generates an image from latents.| -|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.| -|ONNX Text to Latents | Generates latents from conditionings.| -|ONNX Model Loader | Loads a main model, outputting its submodels.| -|OpenCV Inpaint | Simple inpaint using opencv.| -|Openpose Processor | Applies Openpose processing to image| -|PIDI Processor | Applies PIDI processing to image| -|Prompts from File | Loads prompts from a text file| -|Random Integer | Outputs a single random integer.| -|Random Range | Creates a collection of random numbers| -|Integer Range | Creates a range of numbers from start to stop with step| -|Integer Range of Size | Creates a range from start to start + size with step| -|Resize Latents | Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.| -|SDXL Compel Prompt | Parse prompt using compel package to conditioning.| -|SDXL LoRA Loader | Apply selected lora to unet and text_encoder.| -|SDXL Main Model Loader | Loads an sdxl base model, outputting its submodels.| -|SDXL Refiner Compel Prompt | Parse prompt using compel package to conditioning.| -|SDXL Refiner Model Loader | Loads an sdxl refiner model, outputting its submodels.| -|Scale Latents | Scales latents by a given factor.| -|Segment Anything Processor | Applies segment anything processing to image| -|Show Image | Displays a provided image, and passes it forward in the pipeline.| -|Step Param Easing | Experimental per-step parameter easing for denoising steps| -|String Primitive Collection | A collection of string primitive values| -|String Primitive | A string primitive value| -|Subtract Integers | Subtracts two numbers| -|Tile Resample Processor | Tile resampler processor| -|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.| -|VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput| -|Zoe (Depth) Processor | Applies Zoe depth processing to image| \ No newline at end of file +| Node | Function | +| :------------------------------------------------------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------- | +| Add Integers | Adds two numbers | +| Boolean Primitive Collection | A collection of boolean primitive values | +| Boolean Primitive | A boolean primitive value | +| Canny Processor | Canny edge detection for ControlNet | +| CenterPadCrop | Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image. | +| CLIP Skip | Skip layers in clip text_encoder model. | +| Collect | Collects values into a collection | +| Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. | +| Color Primitive | A color primitive value | +| Compel Prompt | Parse prompt using compel package to conditioning. | +| Conditioning Primitive Collection | A collection of conditioning tensor primitive values | +| Conditioning Primitive | A conditioning tensor primitive value | +| Content Shuffle Processor | Applies content shuffle processing to image | +| ControlNet | Collects ControlNet info to pass to other nodes | +| Denoise Latents | Denoises noisy latents to decodable images | +| Divide Integers | Divides two numbers | +| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator | +| [FaceMask](./detailedNodes/faceTools.md#facemask) | Generates masks for faces in an image to use with Inpainting | +| [FaceIdentifier](./detailedNodes/faceTools.md#faceidentifier) | Identifies and labels faces in an image | +| [FaceOff](./detailedNodes/faceTools.md#faceoff) | Creates a new image that is a scaled bounding box with a mask on the face for Inpainting | +| Float Math | Perform basic math operations on two floats | +| Float Primitive Collection | A collection of float primitive values | +| Float Primitive | A float primitive value | +| Float Range | Creates a range | +| HED (softedge) Processor | Applies HED edge detection to image | +| Blur Image | Blurs an image | +| Extract Image Channel | Gets a channel from an image. 
| +| Image Primitive Collection | A collection of image primitive values | +| Integer Math | Perform basic math operations on two integers | +| Convert Image Mode | Converts an image to a different mode. | +| Crop Image | Crops an image to a specified box. The box can be outside of the image. | +| Image Hue Adjustment | Adjusts the Hue of an image. | +| Inverse Lerp Image | Inverse linear interpolation of all pixels of an image | +| Image Primitive | An image primitive value | +| Lerp Image | Linear interpolation of all pixels of an image | +| Offset Image Channel | Add to or subtract from an image color channel by a uniform value. | +| Multiply Image Channel | Multiply or Invert an image color channel by a scalar value. | +| Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`. | +| Blur NSFW Image | Add blur to NSFW-flagged images | +| Paste Image | Pastes an image into another image. | +| ImageProcessor | Base class for invocations that preprocess images for ControlNet | +| Resize Image | Resizes an image to specific dimensions | +| Round Float | Rounds a float to a specified number of decimal places | +| Float to Integer | Converts a float to an integer. Optionally rounds to an even multiple of a input number. | +| Scale Image | Scales an image by a factor | +| Image to Latents | Encodes an image into latents. | +| Add Invisible Watermark | Add an invisible watermark to an image | +| Solid Color Infill | Infills transparent areas of an image with a solid color | +| PatchMatch Infill | Infills transparent areas of an image using the PatchMatch algorithm | +| Tile Infill | Infills transparent areas of an image with tiles of the image | +| Integer Primitive Collection | A collection of integer primitive values | +| Integer Primitive | An integer primitive value | +| Iterate | Iterates over a list of items | +| Latents Primitive Collection | A collection of latents tensor primitive values | +| Latents Primitive | A latents tensor primitive value | +| Latents to Image | Generates an image from latents. | +| Leres (Depth) Processor | Applies leres processing to image | +| Lineart Anime Processor | Applies line art anime processing to image | +| Lineart Processor | Applies line art processing to image | +| LoRA Loader | Apply selected lora to unet and text_encoder. | +| Main Model Loader | Loads a main model, outputting its submodels. | +| Combine Mask | Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. | +| Mask Edge | Applies an edge mask to an image | +| Mask from Alpha | Extracts the alpha channel of an image as a mask. | +| Mediapipe Face Processor | Applies mediapipe face processing to image | +| Midas (Depth) Processor | Applies Midas depth processing to image | +| MLSD Processor | Applies MLSD processing to image | +| Multiply Integers | Multiplies two numbers | +| Noise | Generates latent noise. | +| Normal BAE Processor | Applies NormalBae processing to image | +| ONNX Latents to Image | Generates an image from latents. | +| ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in **init** to receive providers. | +| ONNX Text to Latents | Generates latents from conditionings. | +| ONNX Model Loader | Loads a main model, outputting its submodels. | +| OpenCV Inpaint | Simple inpaint using opencv. 
| +| Openpose Processor | Applies Openpose processing to image | +| PIDI Processor | Applies PIDI processing to image | +| Prompts from File | Loads prompts from a text file | +| Random Integer | Outputs a single random integer. | +| Random Range | Creates a collection of random numbers | +| Integer Range | Creates a range of numbers from start to stop with step | +| Integer Range of Size | Creates a range from start to start + size with step | +| Resize Latents | Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8. | +| SDXL Compel Prompt | Parse prompt using compel package to conditioning. | +| SDXL LoRA Loader | Apply selected lora to unet and text_encoder. | +| SDXL Main Model Loader | Loads an sdxl base model, outputting its submodels. | +| SDXL Refiner Compel Prompt | Parse prompt using compel package to conditioning. | +| SDXL Refiner Model Loader | Loads an sdxl refiner model, outputting its submodels. | +| Scale Latents | Scales latents by a given factor. | +| Segment Anything Processor | Applies segment anything processing to image | +| Show Image | Displays a provided image, and passes it forward in the pipeline. | +| Step Param Easing | Experimental per-step parameter easing for denoising steps | +| String Primitive Collection | A collection of string primitive values | +| String Primitive | A string primitive value | +| Subtract Integers | Subtracts two numbers | +| Tile Resample Processor | Tile resampler processor | +| Upscale (RealESRGAN) | Upscales an image using RealESRGAN. | +| VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput | +| Zoe (Depth) Processor | Applies Zoe depth processing to image | diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 780c965a3f..9f37aca13f 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -119,6 +119,61 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): category="image", version="1.2.0", ) +class CenterPadCropInvocation(BaseInvocation): + """Pad or crop an image's sides from the center by specified pixels. 
Positive values are outside of the image.""" + + image: ImageField = InputField(description="The image to crop") + left: int = InputField( + default=0, + description="Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards)", + ) + right: int = InputField( + default=0, + description="Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards)", + ) + top: int = InputField( + default=0, + description="Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards)", + ) + bottom: int = InputField( + default=0, + description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", + ) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image(self.image.image_name) + + # Calculate and create new image dimensions + new_width = image.width + self.right + self.left + new_height = image.height + self.top + self.bottom + image_crop = Image.new(mode="RGBA", size=(new_width, new_height), color=(0, 0, 0, 0)) + + # Paste new image onto input + image_crop.paste(image, (self.left, self.top)) + + image_dto = context.services.images.create( + image=image_crop, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + + +@invocation( + invocation_type="img_pad_crop", + title="Center Pad or Crop Image", + category="image", + tags=["image", "pad", "crop"], + version="1.0.0", +) class ImagePasteInvocation(BaseInvocation, WithMetadata): """Pastes an image into another image.""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index eb0640f389..796ef82dcd 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -78,6 +78,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +# factor is hard-coded to a literal '8' rather than using this constant. +# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
+LATENT_SCALE_FACTOR = 8 + @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): @@ -214,7 +220,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.4.0", + version="1.5.0", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -272,6 +278,9 @@ class DenoiseLatentsInvocation(BaseInvocation): input=Input.Connection, ui_order=7, ) + cfg_rescale_multiplier: float = InputField( + default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier + ) latents: Optional[LatentsField] = InputField( default=None, description=FieldDescriptions.latents, @@ -331,6 +340,7 @@ class DenoiseLatentsInvocation(BaseInvocation): unconditioned_embeddings=uc, text_embeddings=c, guidance_scale=self.cfg_scale, + guidance_rescale_multiplier=self.cfg_rescale_multiplier, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( threshold=0.0, # threshold, @@ -389,9 +399,9 @@ class DenoiseLatentsInvocation(BaseInvocation): exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: - # assuming fixed dimensional scaling of 8:1 for image:latents - control_height_resize = latents_shape[2] * 8 - control_width_resize = latents_shape[3] * 8 + # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. + control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR + control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR if control_input is None: control_list = None elif isinstance(control_input, list) and len(control_input) == 0: @@ -904,12 +914,12 @@ class ResizeLatentsInvocation(BaseInvocation): ) width: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) height: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) @@ -923,7 +933,7 @@ class ResizeLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents.to(device), - size=(self.height // 8, self.width // 8), + size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR), mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) @@ -1161,3 +1171,60 @@ class BlendLatentsInvocation(BaseInvocation): # context.services.latents.set(name, resized_latents) context.services.latents.save(name, blended_latents) return build_latents_output(latents_name=name, latents=blended_latents) + + +# The Crop Latents node was copied from @skunkworxdark's implementation here: +# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80 +@invocation( + "crop_latents", + title="Crop Latents", + tags=["latents", "crop"], + category="latents", + version="1.0.0", +) +# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. +# Currently, if the class names conflict then 'GET /openapi.json' fails. +class CropLatentsCoreInvocation(BaseInvocation): + """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be + divisible by the latent scale factor of 8. 
+ """ + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + x: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + y: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + width: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + height: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.services.latents.get(self.latents.latents_name) + + x1 = self.x // LATENT_SCALE_FACTOR + y1 = self.y // LATENT_SCALE_FACTOR + x2 = x1 + (self.width // LATENT_SCALE_FACTOR) + y2 = y1 + (self.height // LATENT_SCALE_FACTOR) + + cropped_latents = latents[..., y1:y2, x1:x2] + + name = f"{context.graph_execution_state_id}__{self.id}" + context.services.latents.save(name, cropped_latents) + + return build_latents_output(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index d837e6297f..da243966de 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -112,7 +112,7 @@ GENERATION_MODES = Literal[ ] -@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1") +@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.2.0") class CoreMetadataInvocation(BaseInvocation): """Collects core generation metadata into a MetadataField""" @@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation): seed: Optional[int] = InputField(default=None, description="The seed used for noise generation") rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation") cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter") + cfg_rescale_multiplier: Optional[float] = InputField( + default=None, description=FieldDescriptions.cfg_rescale_multiplier + ) steps: Optional[int] = InputField(default=None, description="The number of steps used for inference") scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference") seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis") diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index cb43a52447..4778d98077 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -44,7 +44,7 @@ class DynamicPromptInvocation(BaseInvocation): title="Prompts from File", tags=["prompt", "file"], category="prompt", - version="1.0.0", + version="1.0.1", ) class PromptsFromFileInvocation(BaseInvocation): """Loads prompts from a text file""" @@ -82,7 +82,7 @@ class PromptsFromFileInvocation(BaseInvocation): end_line = start_line + max_prompts if max_prompts <= 0: end_line = 
np.iinfo(np.int32).max - with open(file_path) as f: + with open(file_path, encoding="utf-8") as f: for i, line in enumerate(f): if i >= start_line and i < end_line: prompts.append((pre_prompt or "") + line.strip() + (post_prompt or "")) diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py new file mode 100644 index 0000000000..e59a0530ee --- /dev/null +++ b/invokeai/app/invocations/tiles.py @@ -0,0 +1,180 @@ +import numpy as np +from PIL import Image +from pydantic import BaseModel + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InputField, + InvocationContext, + OutputField, + WithMetadata, + invocation, + invocation_output, +) +from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.utils import Tile + + +class TileWithImage(BaseModel): + tile: Tile + image: ImageField + + +@invocation_output("calculate_image_tiles_output") +class CalculateImageTilesOutput(BaseInvocationOutput): + tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.") + + +@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0") +class CalculateImageTilesInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.") + tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") + overlap: int = InputField( + ge=0, + default=128, + description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount", + ) + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_with_overlap( + image_height=self.image_height, + image_width=self.image_width, + tile_height=self.tile_height, + tile_width=self.tile_width, + overlap=self.overlap, + ) + return CalculateImageTilesOutput(tiles=tiles) + + +@invocation_output("tile_to_properties_output") +class TileToPropertiesOutput(BaseInvocationOutput): + coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.") + coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.") + coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.") + coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.") + + # HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this + # object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design + # principles. These fields are included, because 1) they are often useful in tiled workflows, and 2) they are + # difficult to calculate in a workflow (even though it's just a couple of subtraction nodes the graph gets + # surprisingly complicated). 
+ width: int = OutputField(description="The width of the tile. Equal to coords_right - coords_left.") + height: int = OutputField(description="The height of the tile. Equal to coords_bottom - coords_top.") + + overlap_top: int = OutputField(description="Overlap between this tile and its top neighbor.") + overlap_bottom: int = OutputField(description="Overlap between this tile and its bottom neighbor.") + overlap_left: int = OutputField(description="Overlap between this tile and its left neighbor.") + overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.") + + +@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0") +class TileToPropertiesInvocation(BaseInvocation): + """Split a Tile into its individual properties.""" + + tile: Tile = InputField(description="The tile to split into properties.") + + def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + return TileToPropertiesOutput( + coords_left=self.tile.coords.left, + coords_right=self.tile.coords.right, + coords_top=self.tile.coords.top, + coords_bottom=self.tile.coords.bottom, + width=self.tile.coords.right - self.tile.coords.left, + height=self.tile.coords.bottom - self.tile.coords.top, + overlap_top=self.tile.overlap.top, + overlap_bottom=self.tile.overlap.bottom, + overlap_left=self.tile.overlap.left, + overlap_right=self.tile.overlap.right, + ) + + +@invocation_output("pair_tile_image_output") +class PairTileImageOutput(BaseInvocationOutput): + tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.") + + +@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0") +class PairTileImageInvocation(BaseInvocation): + """Pair an image with its tile properties.""" + + # TODO(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve + # order. Can this be fixed? + + image: ImageField = InputField(description="The tile image.") + tile: Tile = InputField(description="The tile properties.") + + def invoke(self, context: InvocationContext) -> PairTileImageOutput: + return PairTileImageOutput( + tile_with_image=TileWithImage( + tile=self.tile, + image=self.image, + ) + ) + + +@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0") +class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): + """Merge multiple tile images into a single image.""" + + # Inputs + tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.") + blend_amount: int = InputField( + ge=0, + description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", + ) + + def invoke(self, context: InvocationContext) -> ImageOutput: + images = [twi.image for twi in self.tiles_with_images] + tiles = [twi.tile for twi in self.tiles_with_images] + + # Infer the output image dimensions from the max/min tile limits. + height = 0 + width = 0 + for tile in tiles: + height = max(height, tile.coords.bottom) + width = max(width, tile.coords.right) + + # Get all tile images for processing. + # TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly + # existed in memory at an earlier point in the graph. 
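+ # Explanatory note (not part of the original patch): `images` and `tiles` above stay index-aligned
+ # because both are unpacked from the same `tiles_with_images` list. This is why each tile is paired
+ # with its image by the Pair Tile with Image node upstream, rather than relying on collect/iterate
+ # ordering.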
+ tile_np_images: list[np.ndarray] = [] + for image in images: + pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = pil_image.convert("RGB") + tile_np_images.append(np.array(pil_image)) + + # Prepare the output image buffer. + # Check the first tile to determine how many image channels are expected in the output. + channels = tile_np_images[0].shape[-1] + dtype = tile_np_images[0].dtype + np_image = np.zeros(shape=(height, width, channels), dtype=dtype) + + merge_tiles_with_linear_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) + pil_image = Image.fromarray(np_image) + + image_dto = context.services.images.create( + image=pil_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + metadata=self.metadata, + workflow=context.workflow, + ) + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) diff --git a/invokeai/app/shared/fields.py b/invokeai/app/shared/fields.py index dd9cbb7b82..3e841ffbf2 100644 --- a/invokeai/app/shared/fields.py +++ b/invokeai/app/shared/fields.py @@ -2,6 +2,7 @@ class FieldDescriptions: denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps" cfg_scale = "Classifier-Free Guidance scale" + cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" scheduler = "Scheduler to use during inference" positive_cond = "Positive conditioning tensor" negative_cond = "Negative conditioning tensor" diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 826112156d..9176bf1f49 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module): return clip_extra_context_tokens +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim), + ) + + @classmethod + def from_state_dict(cls, state_dict: dict[torch.Tensor]): + """Initialize an MLPProjModel from a state_dict. + + The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict. + + Args: + state_dict (dict[torch.Tensor]): The state_dict of model weights. 
+ + Returns: + MLPProjModel + """ + cross_attention_dim = state_dict["proj.3.weight"].shape[0] + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + + model = cls(cross_attention_dim, clip_embeddings_dim) + + model.load_state_dict(state_dict) + return model + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + class IPAdapter: """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" @@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter): return image_prompt_embeds, uncond_image_prompt_embeds +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter Plus with full features.""" + + def _init_image_proj_model(self, state_dict: dict[torch.Tensor]): + return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype) + + class IPAdapterPlusXL(IPAdapterPlus): """IP-Adapter Plus for SDXL.""" @@ -149,11 +194,9 @@ def build_ip_adapter( ) -> Union[IPAdapter, IPAdapterPlus]: state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu") - # Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it - # contains. - is_plus = "proj.weight" not in state_dict["image_proj"] - - if is_plus: + if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel). + return IPAdapter(state_dict, device=device, dtype=dtype) + elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler). cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] if cross_attention_dim == 768: # SD1 IP-Adapter Plus @@ -163,5 +206,7 @@ def build_ip_adapter( return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) else: raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.") + elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel). 
+ return IPAdapterFull(state_dict, device=device, dtype=dtype) else: - return IPAdapter(state_dict, device=device, dtype=dtype) + raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.") diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 1353e804a7..ae0cc17203 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if isinstance(guidance_scale, list): guidance_scale = guidance_scale[step_index] - noise_pred = self.invokeai_diffuser._combine( - uc_noise_pred, - c_noise_pred, - guidance_scale, - ) + noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale) + guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier + if guidance_rescale_multiplier > 0: + noise_pred = self._rescale_cfg( + noise_pred, + c_noise_pred, + guidance_rescale_multiplier, + ) # compute the previous noisy sample x_t -> x_t-1 step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) @@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): return step_output + @staticmethod + def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7): + """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf.""" + ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True) + ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True) + + x_rescaled = total_noise_pred * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred + return x_final + def _unet_forward( self, latents, diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 6a63c225fc..3e38f9f78d 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -67,13 +67,17 @@ class IPAdapterConditioningInfo: class ConditioningData: unconditioned_embeddings: BasicConditioningInfo text_embeddings: BasicConditioningInfo - guidance_scale: Union[float, List[float]] """ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. """ + guidance_scale: Union[float, List[float]] + """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 . 
+ ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) + """ + guidance_rescale_multiplier: float = 0 extra: Optional[ExtraConditioningInfo] = None scheduler_args: dict[str, Any] = field(default_factory=dict) """ diff --git a/invokeai/backend/tiles/__init__.py b/invokeai/backend/tiles/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py new file mode 100644 index 0000000000..3a678d825e --- /dev/null +++ b/invokeai/backend/tiles/tiles.py @@ -0,0 +1,201 @@ +import math +from typing import Union + +import numpy as np + +from invokeai.backend.tiles.utils import TBLR, Tile, paste + + +def calc_tiles_with_overlap( + image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0 +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + tile_height (int): The tile height in px. All tiles will have this height. + tile_width (int): The tile width in px. All tiles will have this width. + overlap (int, optional): The target overlap between adjacent tiles. If the tiles do not evenly cover the image + shape, then the last row/column of tiles will overlap more than this. Defaults to 0. + + Returns: + list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom. + """ + assert image_height >= tile_height + assert image_width >= tile_width + assert overlap < tile_height + assert overlap < tile_width + + non_overlap_per_tile_height = tile_height - overlap + non_overlap_per_tile_width = tile_width - overlap + + num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height) + num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) + for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + tile = Tile( + coords=TBLR( + top=tile_idx_y * non_overlap_per_tile_height, + bottom=tile_idx_y * non_overlap_per_tile_height + tile_height, + left=tile_idx_x * non_overlap_per_tile_width, + right=tile_idx_x * non_overlap_per_tile_width + tile_width, + ), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + + if tile.coords.bottom > image_height: + # If this tile would go off the bottom of the image, shift it so that it is aligned with the bottom + # of the image. + tile.coords.bottom = image_height + tile.coords.top = image_height - tile_height + + if tile.coords.right > image_width: + # If this tile would go off the right edge of the image, shift it so that it is aligned with the + # right edge of the image. + tile.coords.right = image_width + tile.coords.left = image_width - tile_width + + tiles.append(tile) + + def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: + if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: + return None + return tiles[idx_y * num_tiles_x + idx_x] + + # Iterate over tiles again and calculate overlaps. 
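+ # Worked example (assumed numbers, not from this patch): with image_width=1000, tile_width=576 and
+ # overlap=128, non_overlap_per_tile_width is 448 and num_tiles_x = ceil((1000 - 128) / 448) = 2.
+ # The second tile initially spans 448..1024, is shifted back to 424..1000 to stay inside the image,
+ # and therefore overlaps the first tile (0..576) by 576 - 424 = 152 px -- more than the requested
+ # 128, as the docstring warns for the last row/column.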
+ for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) + top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) + left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + + assert cur_tile is not None + + # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. + if top_neighbor_tile is not None: + cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) + top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + + # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. + if left_neighbor_tile is not None: + cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) + left_neighbor_tile.overlap.right = cur_tile.overlap.left + + return tiles + + +def merge_tiles_with_linear_blending( + dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int +): + """Merge a set of image tiles into `dst_image` with linear blending between the tiles. + + We expect every tile edge to either: + 1) have an overlap of 0, because it is aligned with the image edge, or + 2) have an overlap >= blend_amount. + If neither of these conditions are satisfied, we raise an exception. + + The linear blending is centered at the halfway point of the overlap between adjacent tiles. + + Args: + dst_image (np.ndarray): The destination image. Shape: (H, W, C). + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + tile_images (list[np.ndarray]): The tile images to merge into `dst_image`. + blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. + tiles_and_images = list(zip(tiles, tile_images, strict=True)) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top) + + # Organize tiles into rows. + tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = [] + cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = [] + first_tile_in_cur_row, _ = tiles_and_images[0] + for tile_and_image in tiles_and_images: + tile, _ = tile_and_image + if not ( + tile.coords.top == first_tile_in_cur_row.coords.top + and tile.coords.bottom == first_tile_in_cur_row.coords.bottom + ): + # Store the previous row, and start a new one. + tile_and_image_rows.append(cur_tile_and_image_row) + cur_tile_and_image_row = [] + first_tile_in_cur_row, _ = tile_and_image + + cur_tile_and_image_row.append(tile_and_image) + tile_and_image_rows.append(cur_tile_and_image_row) + + # Prepare 1D linear gradients for blending. + gradient_left_x = np.linspace(start=0.0, stop=1.0, num=blend_amount) + gradient_top_y = np.linspace(start=0.0, stop=1.0, num=blend_amount) + # Convert shape: (blend_amount, ) -> (blend_amount, 1). The extra dimension enables the gradient to be applied + # to a 2D image via broadcasting. Note that no additional dimension is needed on gradient_left_x for + # broadcasting to work correctly. 
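+ # For example, with an assumed blend_amount of 4: gradient_top_y becomes shape (4, 1), so writing it
+ # into mask[rows : rows + 4, :] broadcasts the same weights across every column, while
+ # gradient_left_x stays shape (4,) and broadcasts across rows when written into
+ # mask[:, cols : cols + 4].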
+ gradient_top_y = np.expand_dims(gradient_top_y, axis=1) + + for tile_and_image_row in tile_and_image_rows: + first_tile_in_row, _ = tile_and_image_row[0] + row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top + row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype) + + # Blend the tiles in the row horizontally. + for tile, tile_image in tile_and_image_row: + # We expect the tiles to be ordered left-to-right. For each tile, we construct a mask that applies linear + # blending to the left of the current tile. The inverse linear blending is automatically applied to the + # right of the tiles that have already been pasted by the paste(...) operation. + tile_height, tile_width, _ = tile_image.shape + mask = np.ones(shape=(tile_height, tile_width), dtype=np.float64) + + # Left blending: + if tile.overlap.left > 0: + assert tile.overlap.left >= blend_amount + # Center the blending gradient in the middle of the overlap. + blend_start_left = tile.overlap.left // 2 - blend_amount // 2 + # The region left of the blending region is masked completely. + mask[:, :blend_start_left] = 0.0 + # Apply the blend gradient to the mask. + mask[:, blend_start_left : blend_start_left + blend_amount] = gradient_left_x + # For visual debugging: + # tile_image[:, blend_start_left : blend_start_left + blend_amount] = 0 + + paste( + dst_image=row_image, + src_image=tile_image, + box=TBLR( + top=0, bottom=tile.coords.bottom - tile.coords.top, left=tile.coords.left, right=tile.coords.right + ), + mask=mask, + ) + + # Blend the row into the dst_image vertically. + # We construct a mask that applies linear blending to the top of the current row. The inverse linear blending is + # automatically applied to the bottom of the tiles that have already been pasted by the paste(...) operation. + mask = np.ones(shape=(row_image.shape[0], row_image.shape[1]), dtype=np.float64) + # Top blending: + # (See comments under 'Left blending' for an explanation of the logic.) + # We assume that the entire row has the same vertical overlaps as the first_tile_in_row. 
+ if first_tile_in_row.overlap.top > 0: + assert first_tile_in_row.overlap.top >= blend_amount + blend_start_top = first_tile_in_row.overlap.top // 2 - blend_amount // 2 + mask[:blend_start_top, :] = 0.0 + mask[blend_start_top : blend_start_top + blend_amount, :] = gradient_top_y + # For visual debugging: + # row_image[blend_start_top : blend_start_top + blend_amount, :] = 0 + paste( + dst_image=dst_image, + src_image=row_image, + box=TBLR( + top=first_tile_in_row.coords.top, + bottom=first_tile_in_row.coords.bottom, + left=0, + right=row_image.shape[1], + ), + mask=mask, + ) diff --git a/invokeai/backend/tiles/utils.py b/invokeai/backend/tiles/utils.py new file mode 100644 index 0000000000..4ad40ffa35 --- /dev/null +++ b/invokeai/backend/tiles/utils.py @@ -0,0 +1,47 @@ +from typing import Optional + +import numpy as np +from pydantic import BaseModel, Field + + +class TBLR(BaseModel): + top: int + bottom: int + left: int + right: int + + def __eq__(self, other): + return ( + self.top == other.top + and self.bottom == other.bottom + and self.left == other.left + and self.right == other.right + ) + + +class Tile(BaseModel): + coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.") + overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.") + + def __eq__(self, other): + return self.coords == other.coords and self.overlap == other.overlap + + +def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None): + """Paste a source image into a destination image. + + Args: + dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C). + src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'. + box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted. + mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'. + Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to + `src * mask + dst * (1 - mask)`. 
+ """ + + if mask is None: + dst_image[box.top : box.bottom, box.left : box.right] = src_image + else: + mask = np.expand_dims(mask, -1) + dst_image_box = dst_image[box.top : box.bottom, box.left : box.right] + dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask) diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 6f160bae46..6a6b79c3b7 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -75,6 +75,7 @@ "framer-motion": "^10.16.4", "i18next": "^23.6.0", "i18next-http-backend": "^2.3.1", + "idb-keyval": "^6.2.1", "konva": "^9.2.3", "lodash-es": "^4.17.21", "nanostores": "^0.9.4", diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json index 72809cc19d..b67663d6d2 100644 --- a/invokeai/frontend/web/public/locales/de.json +++ b/invokeai/frontend/web/public/locales/de.json @@ -803,8 +803,7 @@ "canny": "Canny", "hedDescription": "Ganzheitlich verschachtelte Kantenerkennung", "scribble": "Scribble", - "maxFaces": "Maximal Anzahl Gesichter", - "unstarImage": "Markierung aufheben" + "maxFaces": "Maximal Anzahl Gesichter" }, "queue": { "status": "Status", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index e734ae9e08..051528c100 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -252,7 +252,6 @@ "setControlImageDimensions": "Set Control Image Dimensions To W/H", "showAdvanced": "Show Advanced", "toggleControlNet": "Toggle this ControlNet", - "unstarImage": "Unstar Image", "w": "W", "weight": "Weight", "enableIPAdapter": "Enable IP Adapter", @@ -387,6 +386,8 @@ "showGenerations": "Show Generations", "showUploads": "Show Uploads", "singleColumnLayout": "Single Column Layout", + "starImage": "Star Image", + "unstarImage": "Unstar Image", "unableToLoad": "Unable to load Gallery", "uploads": "Uploads", "deleteSelection": "Delete Selection", @@ -608,6 +609,7 @@ }, "metadata": { "cfgScale": "CFG scale", + "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)", "createdBy": "Created By", "fit": "Image to image fit", "generationMode": "Generation Mode", @@ -986,6 +988,7 @@ "unsupportedAnyOfLength": "too many union members ({{count}})", "unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}", "unableToParseFieldType": "unable to parse field type", + "unableToExtractEnumOptions": "unable to extract enum options", "uNetField": "UNet", "uNetFieldDescription": "UNet submodel.", "unhandledInputProperty": "Unhandled input property", @@ -1041,6 +1044,8 @@ "setType": "Set cancel type" }, "cfgScale": "CFG Scale", + "cfgRescaleMultiplier": "CFG Rescale Multiplier", + "cfgRescale": "CFG Rescale", "clipSkip": "CLIP Skip", "clipSkipWithLayerCount": "CLIP Skip {{layerCount}}", "closeViewer": "Close Viewer", @@ -1482,6 +1487,12 @@ "Controls how much your prompt influences the generation process." ] }, + "paramCFGRescaleMultiplier": { + "heading": "CFG Rescale Multiplier", + "paragraphs": [ + "Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7." 
+ ] + }, "paramDenoisingStrength": { "heading": "Denoising Strength", "paragraphs": [ diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh_CN.json index 03838520d3..24105f2b40 100644 --- a/invokeai/frontend/web/public/locales/zh_CN.json +++ b/invokeai/frontend/web/public/locales/zh_CN.json @@ -1137,8 +1137,7 @@ "openPose": "Openpose", "controlAdapter_other": "Control Adapters", "lineartAnime": "Lineart Anime", - "canny": "Canny", - "unstarImage": "取消收藏图像" + "canny": "Canny" }, "queue": { "status": "状态", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 63533aee0d..73bd92ffab 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -21,6 +21,7 @@ import GlobalHotkeys from './GlobalHotkeys'; import PreselectedImage from './PreselectedImage'; import Toaster from './Toaster'; import { useSocketIO } from 'app/hooks/useSocketIO'; +import { useClearStorage } from 'common/hooks/useClearStorage'; const DEFAULT_CONFIG = {}; @@ -36,15 +37,16 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { const language = useAppSelector(languageSelector); const logger = useLogger('system'); const dispatch = useAppDispatch(); + const clearStorage = useClearStorage(); // singleton! useSocketIO(); const handleReset = useCallback(() => { - localStorage.clear(); + clearStorage(); location.reload(); return false; - }, []); + }, [clearStorage]); useEffect(() => { i18n.changeLanguage(language); diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 459ac65635..b190a36f06 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -7,21 +7,23 @@ import { $headerComponent } from 'app/store/nanostores/headerComponent'; import { $isDebugging } from 'app/store/nanostores/isDebugging'; import { $projectId } from 'app/store/nanostores/projectId'; import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId'; -import { store } from 'app/store/store'; +import { $store } from 'app/store/nanostores/store'; +import { createStore } from 'app/store/store'; import { PartialAppConfig } from 'app/types/invokeai'; +import Loading from 'common/components/Loading/Loading'; +import AppDndContext from 'features/dnd/components/AppDndContext'; +import 'i18n'; import React, { PropsWithChildren, ReactNode, lazy, memo, useEffect, + useMemo, } from 'react'; import { Provider } from 'react-redux'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { ManagerOptions, SocketOptions } from 'socket.io-client'; -import Loading from 'common/components/Loading/Loading'; -import AppDndContext from 'features/dnd/components/AppDndContext'; -import 'i18n'; const App = lazy(() => import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -137,6 +139,14 @@ const InvokeAIUI = ({ }; }, [isDebugging]); + const store = useMemo(() => { + return createStore(projectId); + }, [projectId]); + + useEffect(() => { + $store.set(store); + }, [store]); + return ( diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index a9d56a7f16..ba0aaa5823 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ 
b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -9,9 +9,9 @@ import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme'; import '@fontsource-variable/inter'; import { MantineProvider } from '@mantine/core'; +import { useMantineTheme } from 'mantine-theme/theme'; import 'overlayscrollbars/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css'; -import { useMantineTheme } from 'mantine-theme/theme'; type ThemeLocaleProviderProps = { children: ReactNode; diff --git a/invokeai/frontend/web/src/app/store/constants.ts b/invokeai/frontend/web/src/app/store/constants.ts index 6d48762bef..c2f3a5e10b 100644 --- a/invokeai/frontend/web/src/app/store/constants.ts +++ b/invokeai/frontend/web/src/app/store/constants.ts @@ -1,8 +1 @@ -export const LOCALSTORAGE_KEYS = [ - 'chakra-ui-color-mode', - 'i18nextLng', - 'ROARR_FILTER', - 'ROARR_LOG', -]; - -export const LOCALSTORAGE_PREFIX = '@@invokeai-'; +export const STORAGE_PREFIX = '@@invokeai-'; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index d9bc7b085d..0e3634468b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -23,16 +23,16 @@ import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; import dynamicMiddlewares from 'redux-dynamic-middlewares'; -import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import { Driver, rememberEnhancer, rememberReducer } from 'redux-remember'; import { api } from 'services/api'; -import { LOCALSTORAGE_PREFIX } from './constants'; +import { STORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { $store } from './nanostores/store'; +import { createStore as createIDBKeyValStore, get, set } from 'idb-keyval'; const allReducers = { canvas: canvasReducer, @@ -74,57 +74,70 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'modelmanager', ]; -export const store = configureStore({ - reducer: rememberedRootReducer, - enhancers: (existingEnhancers) => { - return existingEnhancers - .concat( - rememberEnhancer(window.localStorage, rememberedKeys, { - persistDebounce: 300, - serialize, - unserialize, - prefix: LOCALSTORAGE_PREFIX, - }) - ) - .concat(autoBatchEnhancer()); - }, - middleware: (getDefaultMiddleware) => - getDefaultMiddleware({ - serializableCheck: false, - immutableCheck: false, - }) - .concat(api.middleware) - .concat(dynamicMiddlewares) - .prepend(listenerMiddleware.middleware), - devTools: { - actionSanitizer, - stateSanitizer, - trace: true, - predicate: (state, action) => { - // TODO: hook up to the log level param in system slice - // manually type state, cannot type the arg - // const typedState = state as ReturnType; +// Create a custom idb-keyval store (just needed to customize the name) +export const idbKeyValStore = createIDBKeyValStore('invoke', 'invoke-store'); - // TODO: doing this breaks the rtk query devtools, commenting out for now - // if (action.type.startsWith('api/')) { - // // don't log api actions, 
with manual cache updates they are extremely noisy - // return false; - // } +// Create redux-remember driver, wrapping idb-keyval +const idbKeyValDriver: Driver = { + getItem: (key) => get(key, idbKeyValStore), + setItem: (key, value) => set(key, value, idbKeyValStore), +}; - if (actionsDenylist.includes(action.type)) { - // don't log other noisy actions - return false; - } - - return true; +export const createStore = (uniqueStoreKey?: string) => + configureStore({ + reducer: rememberedRootReducer, + enhancers: (existingEnhancers) => { + return existingEnhancers + .concat( + rememberEnhancer(idbKeyValDriver, rememberedKeys, { + persistDebounce: 300, + serialize, + unserialize, + prefix: uniqueStoreKey + ? `${STORAGE_PREFIX}${uniqueStoreKey}-` + : STORAGE_PREFIX, + }) + ) + .concat(autoBatchEnhancer()); }, - }, -}); + middleware: (getDefaultMiddleware) => + getDefaultMiddleware({ + serializableCheck: false, + immutableCheck: false, + }) + .concat(api.middleware) + .concat(dynamicMiddlewares) + .prepend(listenerMiddleware.middleware), + devTools: { + actionSanitizer, + stateSanitizer, + trace: true, + predicate: (state, action) => { + // TODO: hook up to the log level param in system slice + // manually type state, cannot type the arg + // const typedState = state as ReturnType; -export type AppGetState = typeof store.getState; -export type RootState = ReturnType; + // TODO: doing this breaks the rtk query devtools, commenting out for now + // if (action.type.startsWith('api/')) { + // // don't log api actions, with manual cache updates they are extremely noisy + // return false; + // } + + if (actionsDenylist.includes(action.type)) { + // don't log other noisy actions + return false; + } + + return true; + }, + }, + }); + +export type AppGetState = ReturnType< + ReturnType['getState'] +>; +export type RootState = ReturnType['getState']>; // eslint-disable-next-line @typescript-eslint/no-explicit-any export type AppThunkDispatch = ThunkDispatch; -export type AppDispatch = typeof store.dispatch; +export type AppDispatch = ReturnType['dispatch']; export const stateSelector = (state: RootState) => state; -$store.set(store); diff --git a/invokeai/frontend/web/src/common/components/IAIInformationalPopover/constants.ts b/invokeai/frontend/web/src/common/components/IAIInformationalPopover/constants.ts index 197f5f4068..8960399b48 100644 --- a/invokeai/frontend/web/src/common/components/IAIInformationalPopover/constants.ts +++ b/invokeai/frontend/web/src/common/components/IAIInformationalPopover/constants.ts @@ -25,6 +25,7 @@ export type Feature = | 'lora' | 'noiseUseCPU' | 'paramCFGScale' + | 'paramCFGRescaleMultiplier' | 'paramDenoisingStrength' | 'paramIterations' | 'paramModel' diff --git a/invokeai/frontend/web/src/common/hooks/useClearStorage.ts b/invokeai/frontend/web/src/common/hooks/useClearStorage.ts new file mode 100644 index 0000000000..0ab4936d72 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useClearStorage.ts @@ -0,0 +1,12 @@ +import { idbKeyValStore } from 'app/store/store'; +import { clear } from 'idb-keyval'; +import { useCallback } from 'react'; + +export const useClearStorage = () => { + const clearStorage = useCallback(() => { + clear(idbKeyValStore); + localStorage.clear(); + }, []); + + return clearStorage; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterImagePreview.tsx index e12abf4830..b3b584d07e 100644 --- 
a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterImagePreview.tsx @@ -5,14 +5,19 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; +import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; +import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage'; +import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage'; +import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType'; +import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'features/dnd/types'; import { setHeight, setWidth } from 'features/parameters/store/generationSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { memo, useCallback, useMemo, useState } from 'react'; +import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa'; import { @@ -22,11 +27,6 @@ import { useRemoveImageFromBoardMutation, } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; -import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; -import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage'; -import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage'; -import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType'; type Props = { id: string; @@ -35,13 +35,15 @@ type Props = { const selector = createSelector( stateSelector, - ({ controlAdapters, gallery }) => { + ({ controlAdapters, gallery, system }) => { const { pendingControlImages } = controlAdapters; const { autoAddBoardId } = gallery; + const { isConnected } = system; return { pendingControlImages, autoAddBoardId, + isConnected, }; }, defaultSelectorOptions @@ -55,18 +57,19 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { pendingControlImages, autoAddBoardId } = useAppSelector(selector); + const { pendingControlImages, autoAddBoardId, isConnected } = + useAppSelector(selector); const activeTabName = useAppSelector(activeTabNameSelector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); - const { currentData: controlImage } = useGetImageDTOQuery( - controlImageName ?? skipToken - ); + const { currentData: controlImage, isError: isErrorControlImage } = + useGetImageDTOQuery(controlImageName ?? skipToken); - const { currentData: processedControlImage } = useGetImageDTOQuery( - processedControlImageName ?? 
skipToken - ); + const { + currentData: processedControlImage, + isError: isErrorProcessedControlImage, + } = useGetImageDTOQuery(processedControlImageName ?? skipToken); const [changeIsIntermediate] = useChangeImageIsIntermediateMutation(); const [addToBoard] = useAddImageToBoardMutation(); @@ -158,6 +161,17 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => { !pendingControlImages.includes(id) && processorType !== 'none'; + useEffect(() => { + if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) { + handleResetControlImage(); + } + }, [ + handleResetControlImage, + isConnected, + isErrorControlImage, + isErrorProcessedControlImage, + ]); + return ( { icon={customStarUi ? customStarUi.off.icon : } onClickCapture={handleUnstarImage} > - {customStarUi ? customStarUi.off.text : t('controlnet.unstarImage')} + {customStarUi ? customStarUi.off.text : t('gallery.unstarImage')} ) : ( } onClickCapture={handleStarImage} > - {customStarUi ? customStarUi.on.text : `Star Image`} + {customStarUi ? customStarUi.on.text : t('gallery.starImage')} )} { recallNegativePrompt, recallSeed, recallCfgScale, + recallCfgRescaleMultiplier, recallModel, recallScheduler, recallVaeModel, @@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => { recallCfgScale(metadata?.cfg_scale); }, [metadata?.cfg_scale, recallCfgScale]); + const handleRecallCfgRescaleMultiplier = useCallback(() => { + recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier); + }, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]); + const handleRecallStrength = useCallback(() => { recallStrength(metadata?.strength); }, [metadata?.strength, recallStrength]); @@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => { onClick={handleRecallCfgScale} /> )} + {metadata.cfg_rescale_multiplier !== undefined && + metadata.cfg_rescale_multiplier !== null && ( + + )} {metadata.strength && ( { const { nodeId, field } = props; const dispatch = useAppDispatch(); - - const { currentData: imageDTO } = useGetImageDTOQuery( + const isConnected = useAppSelector((state) => state.system.isConnected); + const { currentData: imageDTO, isError } = useGetImageDTOQuery( field.value?.image_name ?? skipToken ); @@ -67,6 +67,12 @@ const ImageFieldInputComponent = ( [nodeId, field.name] ); + useEffect(() => { + if (isConnected && isError) { + handleReset(); + } + }, [handleReset, isConnected, isError]); + return ( = // valid `any`! @@ -321,7 +326,28 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder< const buildEnumFieldInputTemplate: FieldInputTemplateBuilder< EnumFieldInputTemplate > = ({ schemaObject, baseField, isCollection, isCollectionOrScalar }) => { - const options = schemaObject.enum ?? []; + let options: EnumFieldInputTemplate['options'] = []; + if (schemaObject.anyOf) { + const filteredAnyOf = schemaObject.anyOf.filter((i) => { + if (isSchemaObject(i)) { + if (i.type === 'null') { + return false; + } + } + return true; + }); + const firstAnyOf = filteredAnyOf[0]; + if (filteredAnyOf.length !== 1 || !isSchemaObject(firstAnyOf)) { + options = []; + } else { + options = firstAnyOf.enum ?? []; + } + } else { + options = schemaObject.enum ?? 
[]; + } + if (options.length === 0) { + throw new FieldParseError(t('nodes.unableToExtractEnumOptions')); + } const template: EnumFieldInputTemplate = { ...baseField, type: { diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 3b6fadd8a1..4ee4edce1b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -1,10 +1,4 @@ -import { t } from 'i18next'; -import { isArray } from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { - FieldTypeParseError, - UnsupportedFieldTypeError, -} from 'features/nodes/types/error'; +import { FieldParseError } from 'features/nodes/types/error'; import { FieldType } from 'features/nodes/types/field'; import { OpenAPIV3_1SchemaOrRef, @@ -14,6 +8,9 @@ import { isRefObject, isSchemaObject, } from 'features/nodes/types/openapi'; +import { t } from 'i18next'; +import { isArray } from 'lodash-es'; +import { OpenAPIV3_1 } from 'openapi-types'; /** * Transforms an invocation output ref object to field type. @@ -70,7 +67,7 @@ export const parseFieldType = ( // This is a single ref type const name = refObjectToSchemaName(allOf[0]); if (!name) { - throw new FieldTypeParseError( + throw new FieldParseError( t('nodes.unableToExtractSchemaNameFromRef') ); } @@ -95,7 +92,7 @@ export const parseFieldType = ( if (isRefObject(filteredAnyOf[0])) { const name = refObjectToSchemaName(filteredAnyOf[0]); if (!name) { - throw new FieldTypeParseError( + throw new FieldParseError( t('nodes.unableToExtractSchemaNameFromRef') ); } @@ -120,7 +117,7 @@ export const parseFieldType = ( if (filteredAnyOf.length !== 2) { // This is a union of more than 2 types, which we don't support - throw new UnsupportedFieldTypeError( + throw new FieldParseError( t('nodes.unsupportedAnyOfLength', { count: filteredAnyOf.length, }) @@ -167,7 +164,7 @@ export const parseFieldType = ( }; } - throw new UnsupportedFieldTypeError( + throw new FieldParseError( t('nodes.unsupportedMismatchedUnion', { firstType, secondType, @@ -186,7 +183,7 @@ export const parseFieldType = ( if (isSchemaObject(schemaObject.items)) { const itemType = schemaObject.items.type; if (!itemType || isArray(itemType)) { - throw new UnsupportedFieldTypeError( + throw new FieldParseError( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -196,7 +193,7 @@ export const parseFieldType = ( const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new UnsupportedFieldTypeError( + throw new FieldParseError( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -212,7 +209,7 @@ export const parseFieldType = ( // This is a ref object, extract the type name const name = refObjectToSchemaName(schemaObject.items); if (!name) { - throw new FieldTypeParseError( + throw new FieldParseError( t('nodes.unableToExtractSchemaNameFromRef') ); } @@ -226,7 +223,7 @@ export const parseFieldType = ( const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new UnsupportedFieldTypeError( + throw new FieldParseError( t('nodes.unsupportedArrayItemType', { type: schemaObject.type, }) @@ -242,9 +239,7 @@ export const parseFieldType = ( } else if (isRefObject(schemaObject)) { const name = refObjectToSchemaName(schemaObject); if (!name) { - throw new FieldTypeParseError( - 
t('nodes.unableToExtractSchemaNameFromRef') - ); + throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -252,5 +247,5 @@ export const parseFieldType = ( isCollectionOrScalar: false, }; } - throw new FieldTypeParseError(t('nodes.unableToParseFieldType')); + throw new FieldParseError(t('nodes.unableToParseFieldType')); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 94dca71048..9ad391d7c3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -1,12 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { t } from 'i18next'; -import { reduce } from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { - FieldTypeParseError, - UnsupportedFieldTypeError, -} from 'features/nodes/types/error'; +import { FieldParseError } from 'features/nodes/types/error'; import { FieldInputTemplate, FieldOutputTemplate, @@ -18,6 +12,9 @@ import { isInvocationOutputSchemaObject, isInvocationSchemaObject, } from 'features/nodes/types/openapi'; +import { t } from 'i18next'; +import { reduce } from 'lodash-es'; +import { OpenAPIV3_1 } from 'openapi-types'; import { buildFieldInputTemplate } from './buildFieldInputTemplate'; import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; import { parseFieldType } from './parseFieldType'; @@ -126,10 +123,7 @@ export const parseSchema = ( inputsAccumulator[propertyName] = fieldInputTemplate; } catch (e) { - if ( - e instanceof FieldTypeParseError || - e instanceof UnsupportedFieldTypeError - ) { + if (e instanceof FieldParseError) { logger('nodes').warn( { node: type, @@ -218,10 +212,7 @@ export const parseSchema = ( outputsAccumulator[propertyName] = fieldOutputTemplate; } catch (e) { - if ( - e instanceof FieldTypeParseError || - e instanceof UnsupportedFieldTypeError - ) { + if (e instanceof FieldParseError) { logger('nodes').warn( { node: type, diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse.tsx index 524076162f..7190fb9a58 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse.tsx @@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next'; import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise'; import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless'; import ParamClipSkip from './ParamClipSkip'; +import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier'; const selector = createSelector( stateSelector, (state: RootState) => { - const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = - state.generation; + const { + clipSkip, + model, + seamlessXAxis, + seamlessYAxis, + shouldUseCpuNoise, + cfgRescaleMultiplier, + } = state.generation; - return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise }; + return { + clipSkip, + model, + seamlessXAxis, + seamlessYAxis, + shouldUseCpuNoise, + cfgRescaleMultiplier, + }; }, defaultSelectorOptions ); export default function 
ParamAdvancedCollapse() { - const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = - useAppSelector(selector); + const { + clipSkip, + model, + seamlessXAxis, + seamlessYAxis, + shouldUseCpuNoise, + cfgRescaleMultiplier, + } = useAppSelector(selector); const { t } = useTranslation(); const activeLabel = useMemo(() => { const activeLabel: string[] = []; @@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() { activeLabel.push(t('parameters.seamlessY')); } + if (cfgRescaleMultiplier) { + activeLabel.push(t('parameters.cfgRescale')); + } + return activeLabel.join(', '); - }, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]); + }, [ + cfgRescaleMultiplier, + clipSkip, + model, + seamlessXAxis, + seamlessYAxis, + shouldUseCpuNoise, + t, + ]); return ( @@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() { )} + + ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamCFGRescaleMultiplier.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamCFGRescaleMultiplier.tsx new file mode 100644 index 0000000000..2a65b32028 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamCFGRescaleMultiplier.tsx @@ -0,0 +1,60 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; +import IAISlider from 'common/components/IAISlider'; +import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +const selector = createSelector( + [stateSelector], + ({ generation, hotkeys }) => { + const { cfgRescaleMultiplier } = generation; + const { shift } = hotkeys; + + return { + cfgRescaleMultiplier, + shift, + }; + }, + defaultSelectorOptions +); + +const ParamCFGRescaleMultiplier = () => { + const { cfgRescaleMultiplier, shift } = useAppSelector(selector); + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const handleChange = useCallback( + (v: number) => dispatch(setCfgRescaleMultiplier(v)), + [dispatch] + ); + + const handleReset = useCallback( + () => dispatch(setCfgRescaleMultiplier(0)), + [dispatch] + ); + + return ( + + + + ); +}; + +export default memo(ParamCFGRescaleMultiplier); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx index 52eac42700..5f61de5b7d 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { stateSelector } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import { IAINoContentFallback } 
from 'common/components/IAIImageFallback'; @@ -9,25 +9,30 @@ import { TypesafeDraggableData, TypesafeDroppableData, } from 'features/dnd/types'; -import { memo, useMemo } from 'react'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { memo, useEffect, useMemo } from 'react'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; const selector = createSelector( [stateSelector], (state) => { const { initialImage } = state.generation; + const { isConnected } = state.system; + return { initialImage, isResetButtonDisabled: !initialImage, + isConnected, }; }, defaultSelectorOptions ); const InitialImage = () => { - const { initialImage } = useAppSelector(selector); + const dispatch = useAppDispatch(); + const { initialImage, isConnected } = useAppSelector(selector); - const { currentData: imageDTO } = useGetImageDTOQuery( + const { currentData: imageDTO, isError } = useGetImageDTOQuery( initialImage?.imageName ?? skipToken ); @@ -49,6 +54,13 @@ const InitialImage = () => { [] ); + useEffect(() => { + if (isError && isConnected) { + // The image doesn't exist, reset init image + dispatch(clearInitialImage()); + } + }, [dispatch, isConnected, isError]); + return ( { [dispatch, parameterSetToast, parameterNotSetToast] ); + /** + * Recall CFG rescale multiplier with toast + */ + const recallCfgRescaleMultiplier = useCallback( + (cfgRescaleMultiplier: unknown) => { + if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { + parameterNotSetToast(); + return; + } + dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + /** * Recall model with toast */ @@ -799,6 +816,7 @@ export const useRecallParameters = () => { const { cfg_scale, + cfg_rescale_multiplier, height, model, positive_prompt, @@ -831,6 +849,10 @@ export const useRecallParameters = () => { dispatch(setCfgScale(cfg_scale)); } + if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { + dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + } + if (isParameterModel(model)) { dispatch(modelSelected(model)); } @@ -985,6 +1007,7 @@ export const useRecallParameters = () => { recallSDXLNegativeStylePrompt, recallSeed, recallCfgScale, + recallCfgRescaleMultiplier, recallModel, recallScheduler, recallVaeModel, diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 8b7b8cb487..49835601d2 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -24,6 +24,7 @@ import { ParameterVAEModel, ParameterWidth, zParameterModel, + ParameterCFGRescaleMultiplier, } from 'features/parameters/types/parameterSchemas'; export interface GenerationState { @@ -31,6 +32,7 @@ export interface GenerationState { hrfStrength: ParameterStrength; hrfMethod: ParameterHRFMethod; cfgScale: ParameterCFGScale; + cfgRescaleMultiplier: ParameterCFGRescaleMultiplier; height: ParameterHeight; img2imgStrength: ParameterStrength; infillMethod: string; @@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = { hrfEnabled: false, hrfMethod: 'ESRGAN', cfgScale: 7.5, + cfgRescaleMultiplier: 0, height: 512, img2imgStrength: 0.75, infillMethod: 'patchmatch', @@ -145,9 +148,15 @@ export const generationSlice = createSlice({ state.steps ); }, - setCfgScale: (state, action: PayloadAction) => { + setCfgScale: 
(state, action: PayloadAction) => { state.cfgScale = action.payload; }, + setCfgRescaleMultiplier: ( + state, + action: PayloadAction + ) => { + state.cfgRescaleMultiplier = action.payload; + }, setThreshold: (state, action: PayloadAction) => { state.threshold = action.payload; }, @@ -336,6 +345,7 @@ export const { resetParametersState, resetSeed, setCfgScale, + setCfgRescaleMultiplier, setWidth, setHeight, toggleSize, diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 99f58f721c..73e7d7d2c3 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale => zParameterCFGScale.safeParse(val).success; // #endregion +// #region CFG Rescale Multiplier +export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1); +export type ParameterCFGRescaleMultiplier = z.infer< + typeof zParameterCFGRescaleMultiplier +>; +export const isParameterCFGRescaleMultiplier = ( + val: unknown +): val is ParameterCFGRescaleMultiplier => + zParameterCFGRescaleMultiplier.safeParse(val).success; +// #endregion + // #region Scheduler export const zParameterScheduler = zSchedulerField; export type ParameterScheduler = z.infer; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index e1eeb19df3..7841a94d3f 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -14,11 +14,11 @@ import { } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { VALID_LOG_LEVELS } from 'app/logging/logger'; -import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useClearStorage } from 'common/hooks/useClearStorage'; import { consoleLogLevelChanged, setEnableImageDebugging, @@ -164,20 +164,14 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => { shouldEnableInformationalPopovers, } = useAppSelector(selector); + const clearStorage = useClearStorage(); + const handleClickResetWebUI = useCallback(() => { - // Only remove our keys - Object.keys(window.localStorage).forEach((key) => { - if ( - LOCALSTORAGE_KEYS.includes(key) || - key.startsWith(LOCALSTORAGE_PREFIX) - ) { - localStorage.removeItem(key); - } - }); + clearStorage(); onSettingsModalClose(); onRefreshModalOpen(); setInterval(() => setCountdown((prev) => prev - 1), 1000); - }, [onSettingsModalClose, onRefreshModalOpen]); + }, [clearStorage, onSettingsModalClose, onRefreshModalOpen]); useEffect(() => { if (countdown <= 0) { diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryContent.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryContent.tsx index 04e51397b0..cd3dccc464 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryContent.tsx +++ 
b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryContent.tsx @@ -1,11 +1,8 @@ import WorkflowLibraryList from 'features/workflowLibrary/components/WorkflowLibraryList'; import WorkflowLibraryListWrapper from 'features/workflowLibrary/components/WorkflowLibraryListWrapper'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; const WorkflowLibraryContent = () => { - const { t } = useTranslation(); - return ( diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 9d323c456f..c0d54618f4 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1172,6 +1172,79 @@ export type components = { */ type: "infill_cv2"; }; + /** + * Calculate Image Tiles + * @description Calculate the coordinates and overlaps of tiles that cover a target image shape. + */ + CalculateImageTilesInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Image Width + * @description The image width, in pixels, to calculate tiles for. + * @default 1024 + */ + image_width?: number; + /** + * Image Height + * @description The image height, in pixels, to calculate tiles for. + * @default 1024 + */ + image_height?: number; + /** + * Tile Width + * @description The tile width, in pixels. + * @default 576 + */ + tile_width?: number; + /** + * Tile Height + * @description The tile height, in pixels. + * @default 576 + */ + tile_height?: number; + /** + * Overlap + * @description The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount + * @default 128 + */ + overlap?: number; + /** + * type + * @default calculate_image_tiles + * @constant + */ + type: "calculate_image_tiles"; + }; + /** CalculateImageTilesOutput */ + CalculateImageTilesOutput: { + /** + * Tiles + * @description The tiles coordinates that cover a particular image shape. + */ + tiles: components["schemas"]["Tile"][]; + /** + * type + * @default calculate_image_tiles_output + * @constant + */ + type: "calculate_image_tiles_output"; + }; /** * CancelByBatchIDsResult * @description Result of canceling by list of batch ids @@ -1228,6 +1301,61 @@ export type components = { */ type: "canny_image_processor"; }; + /** + * Paste Image + * @description Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image. + */ + CenterPadCropInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
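The `Calculate Image Tiles` invocation above takes an image size, a tile size, and a target overlap ("Adjacent tiles will overlap by at least this amount"). Its exact placement algorithm is not shown in this diff; the sketch below is one possible way to place tiles along a single axis under those constraints, with `tile_starts` being a hypothetical helper:

```python
def tile_starts(image_len: int, tile_len: int, min_overlap: int) -> list[int]:
    """Start positions of tiles along one axis so that the tiles cover the full
    range and adjacent tiles overlap by at least `min_overlap` pixels."""
    if tile_len >= image_len:
        return [0]
    stride = tile_len - min_overlap
    starts = list(range(0, image_len - tile_len, stride))
    starts.append(image_len - tile_len)  # last tile sits flush with the far edge
    return starts

# With the invocation's defaults (1024 px image, 576 px tiles, 128 px overlap)
# this yields two tiles per axis, overlapping by 576 - 448 = 128 pixels.
print(tile_starts(1024, 576, 128))  # [0, 448]
```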
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description The image to crop */ + image?: components["schemas"]["ImageField"]; + /** + * Left + * @description Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards) + * @default 0 + */ + left?: number; + /** + * Right + * @description Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards) + * @default 0 + */ + right?: number; + /** + * Top + * @description Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards) + * @default 0 + */ + top?: number; + /** + * Bottom + * @description Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards) + * @default 0 + */ + bottom?: number; + /** + * type + * @default img_paste + * @constant + */ + type: "img_paste"; + }; /** * ClearResult * @description Result of clearing the session queue @@ -2093,6 +2221,11 @@ export type components = { * @description The classifier-free guidance scale parameter */ cfg_scale?: number | null; + /** + * Cfg Rescale Multiplier + * @description Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR + */ + cfg_rescale_multiplier?: number | null; /** * Steps * @description The number of steps used for inference @@ -2264,6 +2397,58 @@ export type components = { */ type: "create_denoise_mask"; }; + /** + * Crop Latents + * @description Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be + * divisible by the latent scale factor of 8. + */ + CropLatentsCoreInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description Latents tensor */ + latents?: components["schemas"]["LatentsField"]; + /** + * X + * @description The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space. + */ + x?: number; + /** + * Y + * @description The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space. + */ + y?: number; + /** + * Width + * @description The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space. + */ + width?: number; + /** + * Height + * @description The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space. 
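The `Crop Latents` invocation above specifies its box in image-space pixels and requires the values to be divisible by the latent scale factor of 8, converting them to latent-space dimensions. A minimal numpy sketch of that coordinate conversion, assuming a latents array laid out as (C, H, W); the real invocation operates on the graph's latents tensors, and `crop_latents` here is a hypothetical stand-in:

```python
import numpy as np

LATENT_SCALE_FACTOR = 8  # from the invocation's description: image space -> latent space

def crop_latents(latents: np.ndarray, x: int, y: int, width: int, height: int) -> np.ndarray:
    """Crop a (C, H, W) latents array using a box given in image-space pixels."""
    for v in (x, y, width, height):
        assert v % LATENT_SCALE_FACTOR == 0, "box values must be divisible by 8"
    x_l, y_l = x // LATENT_SCALE_FACTOR, y // LATENT_SCALE_FACTOR
    w_l, h_l = width // LATENT_SCALE_FACTOR, height // LATENT_SCALE_FACTOR
    return latents[:, y_l : y_l + h_l, x_l : x_l + w_l]

latents = np.zeros((4, 128, 128), dtype=np.float32)  # a 1024x1024 image in latent space
tile = crop_latents(latents, x=448, y=0, width=576, height=576)
print(tile.shape)  # (4, 72, 72)
```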
+ */ + height?: number; + /** + * type + * @default crop_latents + * @constant + */ + type: "crop_latents"; + }; /** CursorPaginatedResults[SessionQueueItemDTO] */ CursorPaginatedResults_SessionQueueItemDTO_: { /** @@ -2416,6 +2601,12 @@ export type components = { * @description T2I-Adapter(s) to apply */ t2i_adapter?: components["schemas"]["T2IAdapterField"] | components["schemas"]["T2IAdapterField"][] | null; + /** + * Cfg Rescale Multiplier + * @description Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR + * @default 0 + */ + cfg_rescale_multiplier?: number; /** @description Latents tensor */ latents?: components["schemas"]["LatentsField"] | null; /** @description The mask to use for the operation */ @@ -3243,7 +3434,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ColorInvocation"] | 
components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"]; + [key: string]: components["schemas"]["LatentsToImageInvocation"] | 
components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FaceMaskInvocation"] | 
components["schemas"]["CvInpaintInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageWatermarkInvocation"]; }; /** * Edges @@ -3280,7 +3471,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | 
components["schemas"]["MetadataOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["T2IAdapterOutput"]; + [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ColorCollectionOutput"]; }; /** * Errors @@ -4289,7 +4480,7 @@ export type components = { type: "image_output"; }; /** - * Paste 
Image + * Center Pad or Crop Image * @description Pastes an image into another image. */ ImagePasteInvocation: { @@ -4338,10 +4529,10 @@ export type components = { crop?: boolean; /** * type - * @default img_paste + * @default img_pad_crop * @constant */ - type: "img_paste"; + type: "img_pad_crop"; }; /** * ImageRecordChanges @@ -5962,6 +6153,47 @@ export type components = { */ merge_dest_directory?: string | null; }; + /** + * Merge Tiles to Image + * @description Merge multiple tile images into a single image. + */ + MergeTilesToImageInvocation: { + /** @description Optional metadata to be saved with the image */ + metadata?: components["schemas"]["MetadataField"] | null; + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Tiles With Images + * @description A list of tile images with tile properties. + */ + tiles_with_images?: components["schemas"]["TileWithImage"][]; + /** + * Blend Amount + * @description The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles. + */ + blend_amount?: number; + /** + * type + * @default merge_tiles_to_image + * @constant + */ + type: "merge_tiles_to_image"; + }; /** * MetadataField * @description Pydantic model for metadata with custom root of type dict[str, Any]. @@ -6929,6 +7161,50 @@ export type components = { */ items: components["schemas"]["WorkflowRecordListItemDTO"][]; }; + /** + * Pair Tile with Image + * @description Pair an image with its tile properties. + */ + PairTileImageInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description The tile image. */ + image?: components["schemas"]["ImageField"]; + /** @description The tile properties. */ + tile?: components["schemas"]["Tile"]; + /** + * type + * @default pair_tile_image + * @constant + */ + type: "pair_tile_image"; + }; + /** PairTileImageOutput */ + PairTileImageOutput: { + /** @description A tile description with its corresponding image. */ + tile_with_image: components["schemas"]["TileWithImage"]; + /** + * type + * @default pair_tile_image_output + * @constant + */ + type: "pair_tile_image_output"; + }; /** * PIDI Processor * @description Applies PIDI processing to image @@ -9074,6 +9350,17 @@ export type components = { */ source?: string | null; }; + /** TBLR */ + TBLR: { + /** Top */ + top: number; + /** Bottom */ + bottom: number; + /** Left */ + left: number; + /** Right */ + right: number; + }; /** * TextualInversionConfig * @description Model config for textual inversion embeddings. @@ -9138,6 +9425,13 @@ export type components = { model_format: null; error?: components["schemas"]["ModelError"] | null; }; + /** Tile */ + Tile: { + /** @description The coordinates of this tile relative to its parent image. 
*/ + coords: components["schemas"]["TBLR"]; + /** @description The amount of overlap with adjacent tiles on each side of this tile. */ + overlap: components["schemas"]["TBLR"]; + }; /** * Tile Resample Processor * @description Tile resampler processor @@ -9177,6 +9471,101 @@ export type components = { */ type: "tile_image_processor"; }; + /** + * Tile to Properties + * @description Split a Tile into its individual properties. + */ + TileToPropertiesInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description The tile to split into properties. */ + tile?: components["schemas"]["Tile"]; + /** + * type + * @default tile_to_properties + * @constant + */ + type: "tile_to_properties"; + }; + /** TileToPropertiesOutput */ + TileToPropertiesOutput: { + /** + * Coords Left + * @description Left coordinate of the tile relative to its parent image. + */ + coords_left: number; + /** + * Coords Right + * @description Right coordinate of the tile relative to its parent image. + */ + coords_right: number; + /** + * Coords Top + * @description Top coordinate of the tile relative to its parent image. + */ + coords_top: number; + /** + * Coords Bottom + * @description Bottom coordinate of the tile relative to its parent image. + */ + coords_bottom: number; + /** + * Width + * @description The width of the tile. Equal to coords_right - coords_left. + */ + width: number; + /** + * Height + * @description The height of the tile. Equal to coords_bottom - coords_top. + */ + height: number; + /** + * Overlap Top + * @description Overlap between this tile and its top neighbor. + */ + overlap_top: number; + /** + * Overlap Bottom + * @description Overlap between this tile and its bottom neighbor. + */ + overlap_bottom: number; + /** + * Overlap Left + * @description Overlap between this tile and its left neighbor. + */ + overlap_left: number; + /** + * Overlap Right + * @description Overlap between this tile and its right neighbor. + */ + overlap_right: number; + /** + * type + * @default tile_to_properties_output + * @constant + */ + type: "tile_to_properties_output"; + }; + /** TileWithImage */ + TileWithImage: { + tile: components["schemas"]["Tile"]; + image: components["schemas"]["ImageField"]; + }; /** UNetField */ UNetField: { /** @description Info to load unet submodel */ @@ -9500,6 +9889,8 @@ export type components = { * @description The version of the workflow schema. */ version: string; + /** @description The category of the workflow (user or system). */ + category: components["schemas"]["WorkflowCategory"]; }; /** WorkflowRecordDTO */ WorkflowRecordDTO: { @@ -9528,8 +9919,6 @@ export type components = { * @description The opened timestamp of the workflow. */ opened_at: string; - /** @description The category of the workflow (user or system). */ - category: components["schemas"]["WorkflowCategory"]; /** @description The workflow. */ workflow: components["schemas"]["Workflow"]; }; @@ -9560,8 +9949,6 @@ export type components = { * @description The opened timestamp of the workflow. */ opened_at: string; - /** @description The category of the workflow (user or system). 
*/ - category: components["schemas"]["WorkflowCategory"]; /** * Description * @description The description of the workflow. @@ -9875,18 +10262,6 @@ export type components = { * @enum {string} */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; - /** - * IPAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - IPAdapterModelFormat: "invokeai"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. @@ -9894,11 +10269,11 @@ export type components = { */ StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusion1ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -9906,11 +10281,11 @@ export type components = { */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + T2IAdapterModelFormat: "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. @@ -9918,11 +10293,23 @@ export type components = { */ CLIPVisionModelFormat: "diffusers"; /** - * T2IAdapterModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * ControlNetModelFormat + * @description An enumeration. + * @enum {string} + */ + ControlNetModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. 
+     * @enum {string}
+     */
+    StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
   };
   responses: never;
   parameters: never;
diff --git a/invokeai/frontend/web/yarn.lock b/invokeai/frontend/web/yarn.lock
index e0a9db1c5e..6c661af24b 100644
--- a/invokeai/frontend/web/yarn.lock
+++ b/invokeai/frontend/web/yarn.lock
@@ -4158,6 +4158,11 @@ i18next@^23.6.0:
   dependencies:
     "@babel/runtime" "^7.22.5"
 
+idb-keyval@^6.2.1:
+  version "6.2.1"
+  resolved "https://registry.yarnpkg.com/idb-keyval/-/idb-keyval-6.2.1.tgz#94516d625346d16f56f3b33855da11bfded2db33"
+  integrity sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==
+
 ieee754@^1.1.13:
   version "1.2.1"
   resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352"
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index 6712196778..6a3ec510a2 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
             "unet_model_id": "runwayml/stable-diffusion-v1-5",
             "unet_model_name": "stable-diffusion-v1-5",
         },
+        # SD1.5, IPAdapterFull
+        {
+            "ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
+            "ip_adapter_model_name": "ip-adapter-full-face_sd15",
+            "base_model": BaseModelType.StableDiffusion1,
+            "unet_model_id": "runwayml/stable-diffusion-v1-5",
+            "unet_model_name": "stable-diffusion-v1-5",
+        },
     ],
 )
 @pytest.mark.slow
diff --git a/tests/backend/tiles/test_tiles.py b/tests/backend/tiles/test_tiles.py
new file mode 100644
index 0000000000..353e65d336
--- /dev/null
+++ b/tests/backend/tiles/test_tiles.py
@@ -0,0 +1,224 @@
+import numpy as np
+import pytest
+
+from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
+from invokeai.backend.tiles.utils import TBLR, Tile
+
+####################################
+# Test calc_tiles_with_overlap(...)
+####################################
+
+
+def test_calc_tiles_with_overlap_single_tile():
+    """Test calc_tiles_with_overlap() behavior when a single tile covers the image."""
+    tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
+
+    expected_tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
+    ]
+
+    assert tiles == expected_tiles
+
+
+def test_calc_tiles_with_overlap_evenly_divisible():
+    """Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles."""
+    # Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles.
+    tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64)
+
+    expected_tiles = [
+        # Row 0
+        Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
+        Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
+        Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
+        # Row 1
+        Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
+        Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
+        Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
+    ]
+
+    assert tiles == expected_tiles
+
+
+def test_calc_tiles_with_overlap_not_evenly_divisible():
+    """Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
+    # Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps.
+    tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64)
+
+    expected_tiles = [
+        # Row 0
+        Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
+        Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
+        Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
+        # Row 1
+        Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
+        Tile(
+            coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
+        ),
+        Tile(
+            coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
+        ),
+    ]
+
+    assert tiles == expected_tiles
+
+
+@pytest.mark.parametrize(
+    ["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"],
+    [
+        (128, 128, 128, 128, 127, False),  # OK
+        (128, 128, 128, 128, 0, False),  # OK
+        (128, 128, 64, 64, 0, False),  # OK
+        (128, 128, 129, 128, 0, True),  # tile_height exceeds image_height.
+        (128, 128, 128, 129, 0, True),  # tile_width exceeds image_width.
+        (128, 128, 64, 128, 64, True),  # overlap equals tile_height.
+        (128, 128, 128, 64, 64, True),  # overlap equals tile_width.
+    ],
+)
+def test_calc_tiles_with_overlap_input_validation(
+    image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
+):
+    """Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
+    if raises:
+        with pytest.raises(AssertionError):
+            calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
+    else:
+        calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
+
+
+#############################################
+# Test merge_tiles_with_linear_blending(...)
+#############################################
+
+
+@pytest.mark.parametrize("blend_amount", [0, 32])
+def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
+    """Test merge_tiles_with_linear_blending(...) behavior when merging horizontally."""
+    # Initialize 2 tiles side-by-side.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
+        Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
+    ]
+
+    dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles. Pixel values are set based on the tile index.
+    tile_images = [
+        np.zeros((512, 512, 3)) + 64,
+        np.zeros((512, 512, 3)) + 128,
+    ]
+
+    # Calculate expected output.
+    expected_output = np.zeros((512, 960, 3), dtype=np.uint8)
+    expected_output[:, : 480 - (blend_amount // 2), :] = 64
+    if blend_amount > 0:
+        gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((1, blend_amount, 1))
+        expected_output[:, 480 - (blend_amount // 2) : 480 + (blend_amount // 2), :] = gradient
+    expected_output[:, 480 + (blend_amount // 2) :, :] = 128
+
+    merge_tiles_with_linear_blending(
+        dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
+    )
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+@pytest.mark.parametrize("blend_amount", [0, 32])
+def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
+    """Test merge_tiles_with_linear_blending(...) behavior when merging vertically."""
+    # Initialize 2 tiles stacked vertically.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
+        Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
+    ]
+
+    dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles. Pixel values are set based on the tile index.
+    tile_images = [
+        np.zeros((512, 512, 3)) + 64,
+        np.zeros((512, 512, 3)) + 128,
+    ]
+
+    # Calculate expected output.
+    expected_output = np.zeros((960, 512, 3), dtype=np.uint8)
+    expected_output[: 480 - (blend_amount // 2), :, :] = 64
+    if blend_amount > 0:
+        gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((blend_amount, 1, 1))
+        expected_output[480 - (blend_amount // 2) : 480 + (blend_amount // 2), :, :] = gradient
+    expected_output[480 + (blend_amount // 2) :, :, :] = 128
+
+    merge_tiles_with_linear_blending(
+        dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
+    )
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
+    any vertically adjacent tiles.
+    """
+    # Initialize 2 tiles stacked vertically.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
+        Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
+    ]
+
+    dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    # blend_amount=128 exceeds overlap of 64, so should raise exception.
+    with pytest.raises(AssertionError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)
+
+
+def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overlap():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
+    any horizontally adjacent tiles.
+    """
+    # Initialize 2 tiles side-by-side.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
+        Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
+    ]
+
+    dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    # blend_amount=128 exceeds overlap of 64, so should raise exception.
+    with pytest.raises(AssertionError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)
+
+
+def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
+    dst_image.
+    """
+    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
+
+    dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3))]
+
+    with pytest.raises(ValueError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)
+
+
+def test_merge_tiles_with_linear_blending_mismatched_list_lengths():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
+    do not match.
+    """
+    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
+
+    dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
+
+    # tile_images is longer than tiles, so should cause an exception.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    with pytest.raises(ValueError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)
diff --git a/tests/backend/tiles/test_utils.py b/tests/backend/tiles/test_utils.py
new file mode 100644
index 0000000000..bbef233ca5
--- /dev/null
+++ b/tests/backend/tiles/test_utils.py
@@ -0,0 +1,101 @@
+import numpy as np
+import pytest
+
+from invokeai.backend.tiles.utils import TBLR, paste
+
+
+def test_paste_no_mask_success():
+    """Test successful paste with mask=None."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+
+    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    src_image[0, :, 0] = 1  # Row of 1s in channel 0.
+    src_image[:, 0, 1] = 2  # Column of 2s in channel 1.
+
+    # Paste in bottom-center of dst_image.
+    box = TBLR(top=2, bottom=5, left=1, right=4)
+
+    # Construct expected output image.
+    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
+    expected_output[2, 1:4, 0] = 1
+    expected_output[2:5, 1, 1] = 2
+
+    paste(dst_image=dst_image, src_image=src_image, box=box)
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+def test_paste_with_mask_success():
+    """Test successful paste with a mask."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+
+    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    src_image[0, :, 0] = 64  # Row of 64s in channel 0.
+    src_image[:, 0, 1] = 128  # Column of 128s in channel 1.
+
+    # Paste in bottom-center of dst_image.
+    box = TBLR(top=2, bottom=5, left=1, right=4)
+
+    # Create a mask that blends the top-left corner of 'src_image' at 50%, and ignores the rest of src_image.
+    mask = np.zeros((3, 3))
+    mask[0, 0] = 0.5
+
+    # Construct expected output image.
+    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
+    expected_output[2, 1, 0] = 32
+    expected_output[2, 1, 1] = 64
+
+    paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+@pytest.mark.parametrize("use_mask", [True, False])
+def test_paste_box_overflows_dst_image(use_mask: bool):
+    """Test that an exception is raised if 'box' overflows the 'dst_image'."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    mask = None
+    if use_mask:
+        mask = np.zeros((3, 3))
+
+    # Construct box that overflows bottom of dst_image.
+    top = 3
+    left = 0
+    box = TBLR(top=top, bottom=top + src_image.shape[0], left=left, right=left + src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+
+@pytest.mark.parametrize("use_mask", [True, False])
+def test_paste_src_image_does_not_match_box(use_mask: bool):
+    """Test that an exception is raised if the 'src_image' shape does not match the 'box' dimensions."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    mask = None
+    if use_mask:
+        mask = np.zeros((3, 3))
+
+    # Construct box that is smaller than src_image.
+    box = TBLR(top=0, bottom=src_image.shape[0] - 1, left=0, right=src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+
+def test_paste_mask_does_not_match_src_image():
+    """Test that an exception is raised if the 'mask' shape is different than the 'src_image' shape."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+
+    # Construct mask that is smaller than the src_image.
+    mask = np.zeros((src_image.shape[0] - 1, src_image.shape[1]))
+
+    # Construct box that matches src_image shape.
+    box = TBLR(top=0, bottom=src_image.shape[0], left=0, right=src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
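The new tests above exercise a split-process-merge tiling API. The following is only an illustrative sketch, not part of the patch: the signatures of calc_tiles_with_overlap and merge_tiles_with_linear_blending are taken from the tests, while process_tile is a hypothetical stand-in for whatever per-tile work (upscaling, denoising, etc.) a real workflow would perform.

import numpy as np

from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending


def process_tile(tile_image: np.ndarray) -> np.ndarray:
    # Hypothetical per-tile step; a real workflow would upscale or denoise here.
    return tile_image


def process_image_in_tiles(image: np.ndarray, tile_size: int = 512, overlap: int = 64, blend: int = 32) -> np.ndarray:
    """Split `image` into overlapping tiles, process each tile, and blend the results back together."""
    image_height, image_width, _ = image.shape

    # Tiles are described by TBLR coords plus per-edge overlaps, as asserted in test_tiles.py.
    tiles = calc_tiles_with_overlap(
        image_height=image_height,
        image_width=image_width,
        tile_height=min(tile_size, image_height),
        tile_width=min(tile_size, image_width),
        overlap=overlap,
    )

    # Crop each tile out of the source image and run the per-tile step on it.
    tile_images = []
    for tile in tiles:
        crop = image[tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right, :]
        tile_images.append(process_tile(crop))

    # Recombine into a destination image, blending the seams between adjacent tiles.
    dst_image = np.zeros_like(image)
    merge_tiles_with_linear_blending(
        dst_image=dst_image,
        tiles=tiles,
        tile_images=tile_images,
        blend_amount=min(blend, overlap) if len(tiles) > 1 else 0,
    )
    return dst_image

The blend amount is capped at the requested overlap, mirroring the constraint in the Merge Tiles to Image invocation above that the blend amount must not exceed the overlap between adjacent tiles.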