feat(nodes): add WithBoard field helper class

This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it.

This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions?

Supporting changes:
- `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board.
- Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`.
- Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629
- Without `LinearUIOutputInvocation`, the `ImagesInferface.update` method is no longer needed, and removed.

Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part.

Note: A followup commit will implement the frontend changes to support this change.
This commit is contained in:
psychedelicious
2024-02-07 16:33:55 +11:00
parent e137071543
commit 7fbdfbf9e5
12 changed files with 78 additions and 134 deletions

View File

@ -8,12 +8,11 @@ import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.fields import (
BoardField,
ColorField,
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
@ -55,7 +54,7 @@ class ShowImageInvocation(BaseInvocation):
category="image",
version="1.2.1",
)
class BlankImageInvocation(BaseInvocation, WithMetadata):
class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Creates a blank image and forwards it to the pipeline"""
width: int = InputField(default=512, description="The width of the image")
@ -78,7 +77,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageCropInvocation(BaseInvocation, WithMetadata):
class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crops an image to a specified box. The box can be outside of the image."""
image: ImageField = InputField(description="The image to crop")
@ -149,7 +148,7 @@ class CenterPadCropInvocation(BaseInvocation):
category="image",
version="1.2.1",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Pastes an image into another image."""
base_image: ImageField = InputField(description="The base image")
@ -196,7 +195,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from")
@ -221,7 +220,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
image1: ImageField = InputField(description="The first image to multiply")
@ -248,7 +247,7 @@ IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
category="image",
version="1.2.1",
)
class ImageChannelInvocation(BaseInvocation, WithMetadata):
class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Gets a channel from an image."""
image: ImageField = InputField(description="The image to get the channel from")
@ -274,7 +273,7 @@ IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F
category="image",
version="1.2.1",
)
class ImageConvertInvocation(BaseInvocation, WithMetadata):
class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Converts an image to a different mode."""
image: ImageField = InputField(description="The image to convert")
@ -297,7 +296,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageBlurInvocation(BaseInvocation, WithMetadata):
class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Blurs an image"""
image: ImageField = InputField(description="The image to blur")
@ -326,7 +325,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
version="1.2.1",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies an unsharp mask filter to an image"""
image: ImageField = InputField(description="The image to use")
@ -394,7 +393,7 @@ PIL_RESAMPLING_MAP = {
category="image",
version="1.2.1",
)
class ImageResizeInvocation(BaseInvocation, WithMetadata):
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Resizes an image to specific dimensions"""
image: ImageField = InputField(description="The image to resize")
@ -424,7 +423,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageScaleInvocation(BaseInvocation, WithMetadata):
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scales an image by a factor"""
image: ImageField = InputField(description="The image to scale")
@ -459,7 +458,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageLerpInvocation(BaseInvocation, WithMetadata):
class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -486,7 +485,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Inverse linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -513,7 +512,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images"""
image: ImageField = InputField(description="The image to check")
@ -548,7 +547,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add an invisible watermark to an image"""
image: ImageField = InputField(description="The image to check")
@ -569,7 +568,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata):
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies an edge mask to an image"""
image: ImageField = InputField(description="The image to apply the mask to")
@ -608,7 +607,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class MaskCombineInvocation(BaseInvocation, WithMetadata):
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
mask1: ImageField = InputField(description="The first mask to combine")
@ -632,7 +631,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ColorCorrectInvocation(BaseInvocation, WithMetadata):
class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
@ -736,7 +735,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adjusts the Hue of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -825,7 +824,7 @@ CHANNEL_FORMATS = {
category="image",
version="1.2.1",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add or subtract a value from a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -881,7 +880,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
category="image",
version="1.2.1",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scale a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -926,41 +925,14 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
version="1.2.1",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation, WithMetadata):
class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description=FieldDescriptions.image)
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None)
return ImageOutput.build(image_dto)
@invocation(
"linear_ui_output",
title="Linear UI Image Output",
tags=["primitives", "image"],
category="primitives",
version="1.0.2",
use_cache=False,
)
class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
"""Handles Linear UI Image Outputting tasks."""
image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.get_dto(self.image.image_name)
image_dto = context.images.update(
image_name=self.image.image_name,
board_id=self.board.board_id if self.board else None,
is_intermediate=self.is_intermediate,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)