From 1fb307abf41a43363ea37ffb624f8b62b0bf61aa Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 25 May 2023 00:28:04 +1000 Subject: [PATCH] feat(nodes): restore canvas functionality (non-latents) --- invokeai/app/invocations/generate.py | 146 +++++++++++++-------------- 1 file changed, 71 insertions(+), 75 deletions(-) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 3b3e5512c7..aa16243093 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -14,14 +14,17 @@ from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput, build_image_output +from .image import ImageOutput from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.stable_diffusion import PipelineIntermediateState from ..util.step_callback import stable_diffusion_step_callback SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] -DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile' +DEFAULT_INFILL_METHOD = ( + "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" +) + class SDImageInvocation(BaseModel): """Helper class to provide all Stable Diffusion raster image invocations with additional config""" @@ -92,7 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # each time it is called. We only need the first one. generate_output = next(outputs) - image_dto = context.services.images_new.create( + image_dto = context.services.images.create( image=generate_output.image, image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, @@ -100,35 +103,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): node_id=self.id, ) - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? - # image_type = ImageType.RESULT - # image_name = context.services.images.create_name( - # context.graph_execution_state_id, self.id - # ) - - # metadata = context.services.metadata.build_metadata( - # session_id=context.graph_execution_state_id, node=self - # ) - - # context.services.images.save( - # image_type, image_name, generate_output.image, metadata - # ) - - # context.services.images_db.set( - # id=image_name, - # image_type=ImageType.RESULT, - # image_category=ImageCategory.GENERAL, - # session_id=context.graph_execution_state_id, - # node_id=self.id, - # metadata=GeneratedImageOrLatentsMetadata(), - # ) - - return build_image_output( - image_type=image_dto.image_type, - image_name=image_dto.image_name, - image=generate_output.image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -164,7 +145,7 @@ class ImageToImageInvocation(TextToImageInvocation): image = ( None if self.image is None - else context.services.images.get( + else context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) ) @@ -194,26 +175,23 @@ class ImageToImageInvocation(TextToImageInvocation): # each time it is called. We only need the first one. generator_output = next(outputs) - result_image = generator_output.image - - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=generator_output.image, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + session_id=context.graph_execution_state_id, + node_id=self.id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) - context.services.images.save(image_type, image_name, result_image, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=result_image, - ) class InpaintInvocation(ImageToImageInvocation): """Generates an image using inpaint.""" @@ -223,16 +201,38 @@ class InpaintInvocation(ImageToImageInvocation): # Inputs mask: Union[ImageField, None] = Field(description="The mask") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") - seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)") + seam_blur: int = Field( + default=16, ge=0, description="The seam inpaint blur radius (px)" + ) seam_strength: float = Field( default=0.75, gt=0, le=1, description="The seam inpaint strength" ) - seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint") - tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)") - infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)") - inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)") - inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)") - inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color") + seam_steps: int = Field( + default=30, ge=1, description="The number of steps to use for seam inpaint" + ) + tile_size: int = Field( + default=32, ge=1, description="The tile infill method size (px)" + ) + infill_method: INFILL_METHODS = Field( + default=DEFAULT_INFILL_METHOD, + description="The method used to infill empty regions (px)", + ) + inpaint_width: Optional[int] = Field( + default=None, + multiple_of=8, + gt=0, + description="The width of the inpaint region (px)", + ) + inpaint_height: Optional[int] = Field( + default=None, + multiple_of=8, + gt=0, + description="The height of the inpaint region (px)", + ) + inpaint_fill: Optional[ColorField] = Field( + default=ColorField(r=127, g=127, b=127, a=255), + description="The solid infill method color", + ) inpaint_replace: float = Field( default=0.0, ge=0.0, @@ -257,14 +257,14 @@ class InpaintInvocation(ImageToImageInvocation): image = ( None if self.image is None - else context.services.images.get( + else context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) ) mask = ( None if self.mask is None - else context.services.images.get(self.mask.image_type, self.mask.image_name) + else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) ) # Handle invalid model parameter @@ -290,23 +290,19 @@ class InpaintInvocation(ImageToImageInvocation): # each time it is called. We only need the first one. generator_output = next(outputs) - result_image = generator_output.image - - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=generator_output.image, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + session_id=context.graph_execution_state_id, + node_id=self.id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, result_image, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=result_image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, )