fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation):
image: ImageField = InputField(description="The image to show")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image.show()
@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata):
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
image_dto = context.images.save(image=image)
@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation):
description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Calculate and create new image dimensions
@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata):
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
crop: bool = InputField(default=False, description="Crop to base image dimensions")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.images.get_pil(self.base_image.image_name)
image = context.images.get_pil(self.image.image_name)
mask = None
@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_mask = image.split()[-1]
@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
image1: ImageField = InputField(description="The first image to multiply")
image2: ImageField = InputField(description="The second image to multiply")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.images.get_pil(self.image1.image_name)
image2 = context.images.get_pil(self.image2.image_name)
@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
channel_image = image.getchannel(self.channel)
@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
converted_image = image.convert(self.mode)
@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
# Metadata
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
blur = (
@ -338,7 +339,7 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
def array_from_pil(self, img):
return numpy.array(img) / 255
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mode = image.mode
@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata):
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata):
)
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata):
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to check")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
logger = context.logger
@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to check")
text: str = InputField(default="InvokeAI", description="Watermark text")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
new_image = InvisibleWatermark.add_watermark(image, self.text)
image_dto = context.images.save(image=new_image)
@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata):
description="Second threshold for the hysteresis procedure in Canny edge detection"
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.images.get_pil(self.image.image_name).convert("L")
npimg = numpy.asarray(mask, dtype=numpy.uint8)
@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata):
mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = InputField(description="The second image to combine")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
mask2 = context.images.get_pil(self.mask2.image_name).convert("L")
@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata):
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None
if self.mask is not None:
pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L")
@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to adjust")
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# Convert image to HSV color space
@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description=FieldDescriptions.image)
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None)
@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.get_dto(self.image.image_name)
image_dto = context.images.update(