mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): wip inpaint node
This commit is contained in:
parent
357cee2849
commit
206e6b1730
@ -1,15 +1,16 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput, build_image_output
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
@ -17,7 +18,9 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
|||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
|
|
||||||
|
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||||
|
|
||||||
class SDImageInvocation(BaseModel):
|
class SDImageInvocation(BaseModel):
|
||||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||||
@ -45,7 +48,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
@ -148,7 +151,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mask = None
|
|
||||||
|
|
||||||
if self.fit:
|
if self.fit:
|
||||||
image = image.resize((self.width, self.height))
|
image = image.resize((self.width, self.height))
|
||||||
@ -165,7 +167,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
outputs = Img2Img(model).generate(
|
outputs = Img2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
@ -197,7 +198,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image=result_image,
|
image=result_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@ -205,6 +205,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
mask: Union[ImageField, None] = Field(description="The mask")
|
mask: Union[ImageField, None] = Field(description="The mask")
|
||||||
|
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||||
|
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||||
|
seam_strength: float = Field(
|
||||||
|
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||||
|
)
|
||||||
|
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||||
|
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||||
|
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||||
|
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||||
|
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||||
|
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||||
inpaint_replace: float = Field(
|
inpaint_replace: float = Field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
|
@ -27,3 +27,10 @@ class ImageField(BaseModel):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_type", "image_name"]}
|
schema_extra = {"required": ["image_type", "image_name"]}
|
||||||
|
|
||||||
|
|
||||||
|
class ColorField(BaseModel):
|
||||||
|
r: int = Field(ge=0, le=255, description="The red component")
|
||||||
|
g: int = Field(ge=0, le=255, description="The green component")
|
||||||
|
b: int = Field(ge=0, le=255, description="The blue component")
|
||||||
|
a: Optional[int] = Field(default=255, ge=0, le=255, description="The alpha component")
|
||||||
|
@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
|
|||||||
latents_name: str
|
latents_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataColorField(TypedDict):
|
||||||
|
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||||
|
r: int
|
||||||
|
g: int
|
||||||
|
b: int
|
||||||
|
a: int
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||||
NodeMetadata = Dict[
|
NodeMetadata = Dict[
|
||||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -226,10 +226,10 @@ class Inpaint(Img2Img):
|
|||||||
def generate(self,
|
def generate(self,
|
||||||
mask_image: Image.Image | torch.FloatTensor,
|
mask_image: Image.Image | torch.FloatTensor,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 96,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 16,
|
||||||
seam_strength: float = 0.7,
|
seam_strength: float = 0.7,
|
||||||
seam_steps: int = 10,
|
seam_steps: int = 30,
|
||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
inpaint_replace=False,
|
inpaint_replace=False,
|
||||||
infill_method=None,
|
infill_method=None,
|
||||||
|
@ -211,10 +211,10 @@ class Inpaint(Img2Img):
|
|||||||
strength: float,
|
strength: float,
|
||||||
mask_blur_radius: int = 8,
|
mask_blur_radius: int = 8,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 96,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 16,
|
||||||
seam_strength: float = 0.7,
|
seam_strength: float = 0.7,
|
||||||
seam_steps: int = 10,
|
seam_steps: int = 30,
|
||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
step_callback=None,
|
step_callback=None,
|
||||||
inpaint_replace=False,
|
inpaint_replace=False,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user