@invocation(
    "unsharp_mask",
    title="Unsharp Mask",
    tags=["image", "unsharp_mask"],
    category="image",
    version="1.2.0",
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
    """Applies an unsharp mask filter to an image"""

    image: ImageField = InputField(description="The image to use")
    radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
    strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)

    def pil_from_array(self, arr):
        # Map a 0..1 float array back to an 8-bit PIL image (values are
        # scaled to 0..255 and truncated by the uint8 cast).
        return Image.fromarray((arr * 255).astype("uint8"))

    def array_from_pil(self, img):
        # Normalize an 8-bit PIL image into a 0..1 float numpy array.
        return numpy.array(img) / 255

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Sharpen the input image via unsharp masking.

        Computes ``img + (img - gaussian_blur(img, radius)) * strength / 100``
        on the RGB channels, clips to the valid range, and restores the
        original mode (including any source alpha channel) before saving.
        """
        source = context.services.images.get_pil_image(self.image.image_name)
        original_mode = source.mode

        # Keep transparency aside — the sharpening arithmetic runs on RGB only.
        alpha_channel = source.getchannel("A") if original_mode == "RGBA" else None
        rgb = source.convert("RGB")

        blurred = self.array_from_pil(rgb.filter(ImageFilter.GaussianBlur(radius=self.radius)))
        sharpened = self.array_from_pil(rgb)
        # Add back the high-frequency detail (image minus its blur), scaled
        # by strength expressed as a percentage.
        sharpened += (sharpened - blurred) * (self.strength / 100.0)
        result = self.pil_from_array(numpy.clip(sharpened, 0, 1))

        # Return to the caller's original mode, then re-attach the untouched
        # source alpha channel if there was one.
        result = result.convert(original_mode)
        if alpha_channel is not None:
            result.putalpha(alpha_channel)

        image_dto = context.services.images.create(
            image=result,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata,
            workflow=context.workflow,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=result.width,
            height=result.height,
        )