feat(nodes): wip inpaint node

This commit is contained in:
psychedelicious 2023-05-05 00:06:34 +10:00
parent 357cee2849
commit 206e6b1730
5 changed files with 40 additions and 13 deletions

View File

@ -1,15 +1,16 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial from functools import partial
from typing import Literal, Optional, Union from typing import Literal, Optional, Union, get_args
import numpy as np import numpy as np
from torch import Tensor from torch import Tensor
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@ -17,7 +18,9 @@ from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
class SDImageInvocation(BaseModel): class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config""" """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@ -45,7 +48,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# fmt: off # fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from") prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", ) seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@ -148,7 +151,6 @@ class ImageToImageInvocation(TextToImageInvocation):
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
) )
mask = None
if self.fit: if self.fit:
image = image.resize((self.width, self.height)) image = image.resize((self.width, self.height))
@ -165,7 +167,6 @@ class ImageToImageInvocation(TextToImageInvocation):
outputs = Img2Img(model).generate( outputs = Img2Img(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_image=image, init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id), step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict( **self.dict(
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
@ -197,7 +198,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=result_image, image=result_image,
) )
class InpaintInvocation(ImageToImageInvocation): class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint.""" """Generates an image using inpaint."""
@ -205,6 +205,17 @@ class InpaintInvocation(ImageToImageInvocation):
# Inputs # Inputs
mask: Union[ImageField, None] = Field(description="The mask") mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
inpaint_replace: float = Field( inpaint_replace: float = Field(
default=0.0, default=0.0,
ge=0.0, ge=0.0,

View File

@ -27,3 +27,10 @@ class ImageField(BaseModel):
class Config: class Config:
schema_extra = {"required": ["image_type", "image_name"]} schema_extra = {"required": ["image_type", "image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: Optional[int] = Field(default=255, ge=0, le=255, description="The alpha component")

View File

@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
latents_name: str latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports # TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[ NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField
] ]

View File

@ -226,10 +226,10 @@ class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image.Image | torch.FloatTensor, mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 0, seam_size: int = 96,
seam_blur: int = 0, seam_blur: int = 16,
seam_strength: float = 0.7, seam_strength: float = 0.7,
seam_steps: int = 10, seam_steps: int = 30,
tile_size: int = 32, tile_size: int = 32,
inpaint_replace=False, inpaint_replace=False,
infill_method=None, infill_method=None,

View File

@ -211,10 +211,10 @@ class Inpaint(Img2Img):
strength: float, strength: float,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 0, seam_size: int = 96,
seam_blur: int = 0, seam_blur: int = 16,
seam_strength: float = 0.7, seam_strength: float = 0.7,
seam_steps: int = 10, seam_steps: int = 30,
tile_size: int = 32, tile_size: int = 32,
step_callback=None, step_callback=None,
inpaint_replace=False, inpaint_replace=False,