From e0b9b5cc6c3af933beab6cfb0bbacc5fc269a4fd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 2 May 2023 20:09:56 +1000 Subject: [PATCH] feat(nodes): add dataURL to image node --- invokeai/app/invocations/image.py | 51 +++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 883ef63f69..7c1465f4e3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1,10 +1,12 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +import io from typing import Literal, Optional import numpy from PIL import Image, ImageFilter, ImageOps -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator +from w3lib.url import parse_data_uri from ..models.image import ImageField, ImageType from .baseinvocation import ( @@ -37,9 +39,7 @@ class ImageOutput(BaseInvocationOutput): # fmt: on class Config: - schema_extra = { - "required": ["type", "image", "width", "height", "mode"] - } + schema_extra = {"required": ["type", "image", "width", "height", "mode"]} def build_image_output( @@ -119,6 +119,45 @@ class ShowImageInvocation(BaseInvocation): ) +class DataURLToImageInvocation(BaseInvocation, PILInvocationConfig): + """Outputs an image from a data URL.""" + + type: Literal["dataURL_image"] = "dataURL_image" + + # Inputs + dataURL: str = Field(description="The b64 data URL") + + @validator("dataURL") + def must_be_valid_image_dataURL(cls, v): + try: + result = parse_data_uri(v) + img = Image.open(io.BytesIO(result.data)) + img.verify() + except Exception: + raise ValueError("Invalid image dataURL") + return v + + def invoke(self, context: InvocationContext) -> ImageOutput: + # TODO: Figure out how to use pydantic validator to also transform into a different type, + # bc this is just the same logic as we use to validate the dataURL. + result = parse_data_uri(self.dataURL) + image = Image.open(io.BytesIO(result.data)) + + image_name = context.services.images.create_name( + context.graph_execution_state_id, self.id + ) + + metadata = context.services.metadata.build_metadata( + session_id=context.graph_execution_state_id, node=self + ) + + context.services.images.save(ImageType.RESULT, image_name, image, metadata) + + return build_image_output( + image_type=ImageType.RESULT, image_name=image_name, image=image + ) + + class CropImageInvocation(BaseInvocation, PILInvocationConfig): """Crops an image to a specified box. The box can be outside of the image.""" @@ -151,7 +190,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig): metadata = context.services.metadata.build_metadata( session_id=context.graph_execution_state_id, node=self ) - + context.services.images.save(image_type, image_name, image_crop, metadata) return build_image_output( image_type=image_type, @@ -209,7 +248,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig): metadata = context.services.metadata.build_metadata( session_id=context.graph_execution_state_id, node=self ) - + context.services.images.save(image_type, image_name, new_image, metadata) return build_image_output( image_type=image_type,