Added more preprocessor nodes for:

    MidasDepth
    ZoeDepth
    MLSD
    NormalBae
    Pidi
    LineartAnime
    ContentShuffle
Removed pil_output options; ControlNet preprocessors should always output PIL images. Also removed diagnostics and did other general cleanup.
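For context, all of these preprocessor nodes share one contract: the base invocation loads the input as a PIL image, calls run_processor(image), and saves the returned PIL image, so each node only overrides run_processor(). Below is a minimal sketch of the HED step under that contract — the bare class and the "lllyasviel/ControlNet" checkpoint id are illustrative assumptions, not this file's exact pydantic node:

```python
from controlnet_aux import HEDdetector
from PIL import Image

class HedSketch:
    # defaults mirror the node's pydantic Fields in the diff below
    detect_resolution: int = 512
    image_resolution: int = 512
    safe: bool = False
    scribble: bool = False

    def run_processor(self, image: Image.Image) -> Image.Image:
        # assumption: detector weights pulled from the lllyasviel/ControlNet repo;
        # the safe= kwarg needs a controlnet_aux version that supports it
        # (the diff notes v0.0.3 did not)
        hed_processor = HEDdetector.from_pretrained("lllyasviel/ControlNet")
        return hed_processor(image,
                             detect_resolution=self.detect_resolution,
                             image_resolution=self.image_resolution,
                             safe=self.safe,
                             scribble=self.scribble)
```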
user1 2023-05-05 14:12:19 -07:00 committed by Kent Keirsey
parent 5acbbeecaa
commit 0e027ec3ef


@@ -1,10 +1,9 @@
 # InvokeAI nodes for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 import numpy as np
 from typing import Literal, Optional, Union, List
-from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field
 from ..models.image import ImageField, ImageType
@@ -26,83 +25,23 @@ from controlnet_aux import (
     OpenposeDetector,
     PidiNetDetector,
     ContentShuffleDetector,
-    ZoeDetector,
-)
+    ZoeDetector)

 from .image import ImageOutput, build_image_output, PILInvocationConfig

-CONTROLNET_DEFAULT_MODELS = [
-    ###########################################
-    # lllyasviel sd v1.5, ControlNet v1.0 models
-    ###########################################
-    "lllyasviel/sd-controlnet-canny",
-    "lllyasviel/sd-controlnet-depth",
-    "lllyasviel/sd-controlnet-hed",
-    "lllyasviel/sd-controlnet-seg",
-    "lllyasviel/sd-controlnet-openpose",
-    "lllyasviel/sd-controlnet-scribble",
-    "lllyasviel/sd-controlnet-normal",
-    "lllyasviel/sd-controlnet-mlsd",
-    ###########################################
-    # lllyasviel sd v1.5, ControlNet v1.1 models
-    ###########################################
-    "lllyasviel/control_v11p_sd15_canny",
-    "lllyasviel/control_v11p_sd15_openpose",
-    "lllyasviel/control_v11p_sd15_seg",
-    # "lllyasviel/control_v11p_sd15_depth",  # broken
-    "lllyasviel/control_v11f1p_sd15_depth",
-    "lllyasviel/control_v11p_sd15_normalbae",
-    "lllyasviel/control_v11p_sd15_scribble",
-    "lllyasviel/control_v11p_sd15_mlsd",
-    "lllyasviel/control_v11p_sd15_softedge",
-    "lllyasviel/control_v11p_sd15s2_lineart_anime",
-    "lllyasviel/control_v11p_sd15_lineart",
-    "lllyasviel/control_v11p_sd15_inpaint",
-    # "lllyasviel/control_v11u_sd15_tile",
-    #   problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
-    #   so for now replace with "lllyasviel/control_v11f1e_sd15_tile",
-    "lllyasviel/control_v11e_sd15_shuffle",
-    "lllyasviel/control_v11e_sd15_ip2p",
-    "lllyasviel/control_v11f1e_sd15_tile",
-    ###########################################
-    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
-    ###########################################
-    "thibaud/controlnet-sd21-openpose-diffusers",
-    "thibaud/controlnet-sd21-canny-diffusers",
-    "thibaud/controlnet-sd21-depth-diffusers",
-    "thibaud/controlnet-sd21-scribble-diffusers",
-    "thibaud/controlnet-sd21-hed-diffusers",
-    "thibaud/controlnet-sd21-zoedepth-diffusers",
-    "thibaud/controlnet-sd21-color-diffusers",
-    "thibaud/controlnet-sd21-openposev2-diffusers",
-    "thibaud/controlnet-sd21-lineart-diffusers",
-    "thibaud/controlnet-sd21-normalbae-diffusers",
-    "thibaud/controlnet-sd21-ade20k-diffusers",
-    ###########################################
-    # ControlNetMediaPipeFace, ControlNet v1.1
-    ###########################################
-    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
-    # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
-    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
-]
-
-CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="processed image")
-    control_model: Optional[str] = Field(default=None, description="control model used")
-    control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
-    begin_step_percent: float = Field(default=0, ge=0, le=1,
-                                      description="% of total steps at which controlnet is first applied")
-    end_step_percent: float = Field(default=1, ge=0, le=1,
-                                    description="% of total steps at which controlnet is last applied")
+    # width: Optional[int] = Field(default=None, description="The width of the image in pixels")
+    # height: Optional[int] = Field(default=None, description="The height of the image in pixels")
+    # mode: Optional[str] = Field(default=None, description="The mode of the image")
+    control_model: Optional[str] = Field(default=None, description="The control model used")
+    control_weight: Optional[float] = Field(default=None, description="The control weight used")

     class Config:
         schema_extra = {
-            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
+            "required": ["image", "control_model", "control_weight"]
+            # "required": ["type", "image", "width", "height", "mode"]
         }
@@ -110,50 +49,29 @@ class ControlOutput(BaseInvocationOutput):
     """node output for ControlNet info"""
     # fmt: off
     type: Literal["control_output"] = "control_output"
-    control: ControlField = Field(default=None, description="The control info dict")
+    control: Optional[ControlField] = Field(default=None, description="The control info dict")
+    raw_processed_image: ImageField = Field(default=None,
+                                            description="outputs just the image info (also included in control output)")
     # fmt: on

-class PreprocessedControlInvocation(BaseInvocation, PILInvocationConfig):
-    """Base class for invocations that preprocess images for ControlNet"""
-
-class ControlNetInvocation(BaseInvocation):
-    """Collects ControlNet info to pass to other nodes"""
-    # fmt: off
-    type: Literal["controlnet"] = "controlnet"
-    # Inputs
-    image: ImageField = Field(default=None, description="image to process")
-    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
-                                                  description="control model used")
-    control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
-    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
-    begin_step_percent: float = Field(default=0, ge=0, le=1,
-                                      description="% of total steps at which controlnet is first applied")
-    end_step_percent: float = Field(default=1, ge=0, le=1,
-                                    description="% of total steps at which controlnet is last applied")
-    # fmt: on
-
-    def invoke(self, context: InvocationContext) -> ControlOutput:
-        return ControlOutput(
-            control=ControlField(
-                image=self.image,
-                control_model=self.control_model,
-                control_weight=self.control_weight,
-                begin_step_percent=self.begin_step_percent,
-                end_step_percent=self.end_step_percent,
-            ),
-        )
-
-# TODO: move image processors to separate file (image_analysis.py)
-class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
+# This super class handles invoke() call, which in turn calls run_processor(image)
+# subclasses override run_processor() instead of implementing their own invoke()
+class PreprocessedControlNetInvocation(BaseInvocation, PILInvocationConfig):
     """Base class for invocations that preprocess images for ControlNet"""
     # fmt: off
-    type: Literal["image_processor"] = "image_processor"
+    type: Literal["preprocessed_control"] = "preprocessed_control"
     # Inputs
     image: ImageField = Field(default=None, description="image to process")
+    control_model: str = Field(default=None, description="control model to use")
+    control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
+    # TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
+    # begin_step_percent: float = Field(default=0, ge=0, le=1,
+    #                                   description="% of total steps at which controlnet is first applied")
+    # end_step_percent: float = Field(default=1, ge=0, le=1,
+    #                                 description="% of total steps at which controlnet is last applied")
+    # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
     # fmt: on
@@ -161,12 +79,12 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
         # superclass just passes through image without processing
         return image

-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        raw_image = context.services.images.get(
+    def invoke(self, context: InvocationContext) -> ControlOutput:
+        image = context.services.images.get(
             self.image.image_type, self.image.image_name
         )
         # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(raw_image)
+        processed_image = self.run_processor(image)
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
         # image_type = ImageType.INTERMEDIATE
@@ -180,22 +98,24 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
         context.services.images.save(image_type, image_name, processed_image, metadata)

         """Builds an ImageOutput and its ImageField"""
-        processed_image_field = ImageField(
+        image_field = ImageField(
             image_name=image_name,
             image_type=image_type,
         )
-        return ImageOutput(
-            image=processed_image_field,
-            width=processed_image.width,
-            height=processed_image.height,
-            mode=processed_image.mode,
+        return ControlOutput(
+            control=ControlField(
+                image=image_field,
+                control_model=self.control_model,
+                control_weight=self.control_weight,
+            ),
+            raw_processed_image=image_field,
         )

-class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class CannyControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Canny edge detection for ControlNet"""
     # fmt: off
-    type: Literal["canny_image_processor"] = "canny_image_processor"
+    type: Literal["cannycontrol"] = "cannycontrol"
     # Input
     low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
     high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
@@ -207,15 +127,14 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class HedControlNetInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies HED edge detection to image"""
     # fmt: off
-    type: Literal["hed_image_processor"] = "hed_image_processor"
+    type: Literal["hed_control"] = "hed_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
-    # safe not supported in controlnet_aux v0.0.3
-    # safe: bool = Field(default=False, description="whether to use safe mode")
+    safe: bool = Field(default=False, description="whether to use safe mode")
     scribble: bool = Field(default=False, description="whether to use scribble mode")
     # fmt: on
@@ -224,17 +143,16 @@ class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         processed_image = hed_processor(image,
                                         detect_resolution=self.detect_resolution,
                                         image_resolution=self.image_resolution,
-                                        # safe not supported in controlnet_aux v0.0.3
-                                        # safe=self.safe,
+                                        safe=self.safe,
                                         scribble=self.scribble,
                                         )
         return processed_image

-class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class LineartControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies line art processing to image"""
     # fmt: off
-    type: Literal["lineart_image_processor"] = "lineart_image_processor"
+    type: Literal["lineart_control"] = "lineart_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
@@ -250,10 +168,10 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class LineartAnimeControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies line art anime processing to image"""
     # fmt: off
-    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
+    type: Literal["lineart_anime_control"] = "lineart_anime_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
@@ -268,10 +186,10 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class OpenposeControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies Openpose processing to image"""
     # fmt: off
-    type: Literal["openpose_image_processor"] = "openpose_image_processor"
+    type: Literal["openpose_control"] = "openpose_control"
     # Inputs
     hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
@@ -288,15 +206,14 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class MidasDepthControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies Midas depth processing to image"""
     # fmt: off
-    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
+    type: Literal["midas_control"] = "midas_control"
     # Inputs
     a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
     bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
-    # depth_and_normal not supported in controlnet_aux v0.0.3
-    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
+    depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
     # fmt: on

     def run_processor(self, image):
@@ -304,16 +221,14 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         processed_image = midas_processor(image,
                                           a=np.pi * self.a_mult,
                                           bg_th=self.bg_th,
-                                          # depth_and_normal not supported in controlnet_aux v0.0.3
-                                          # depth_and_normal=self.depth_and_normal,
-                                          )
+                                          depth_and_normal=self.depth_and_normal)
         return processed_image

-class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class NormalbaeControlNetInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies NormalBae processing to image"""
     # fmt: off
-    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
+    type: Literal["normalbae_control"] = "normalbae_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
@@ -327,10 +242,10 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class MLSDControlNetInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies MLSD processing to image"""
     # fmt: off
-    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
+    type: Literal["mlsd_control"] = "mlsd_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
@@ -348,10 +263,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class PidiControlNetInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies PIDI processing to image"""
     # fmt: off
-    type: Literal["pidi_image_processor"] = "pidi_image_processor"
+    type: Literal["pidi_control"] = "pidi_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
@@ -369,16 +284,16 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class ContentShuffleControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies content shuffle processing to image"""
     # fmt: off
-    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
+    type: Literal["content_shuffle_control"] = "content_shuffle_control"
     # Inputs
     detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
     image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
-    h: Union[int, None] = Field(default=512, ge=0, description="content shuffle h parameter")
-    w: Union[int, None] = Field(default=512, ge=0, description="content shuffle w parameter")
-    f: Union[int, None] = Field(default=256, ge=0, description="content shuffle f parameter")
+    h: Union[int, None] = Field(default=None, ge=0, description="content shuffle h parameter")
+    w: Union[int, None] = Field(default=None, ge=0, description="content shuffle w parameter")
+    f: Union[int, None] = Field(default=None, ge=0, description="content shuffle f parameter")
     # fmt: on

     def run_processor(self, image):
@@ -393,10 +308,10 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+class ZoeDepthControlInvocation(PreprocessedControlNetInvocation, PILInvocationConfig):
     """Applies Zoe depth processing to image"""
     # fmt: off
-    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
+    type: Literal["zoe_depth_control"] = "zoe_depth_control"
     # fmt: on

     def run_processor(self, image):
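Downstream of run_processor(), every node above hands off the same payload. A minimal sketch of the ControlOutput a processor builds, per the ControlField/ControlOutput definitions in this diff (ImageField, ImageType, ControlField, and ControlOutput come from this file; the image_name and model-id literals are made up for illustration):

```python
image_field = ImageField(image_name="canny_edges.png", image_type=ImageType.RESULT)
output = ControlOutput(
    control=ControlField(
        image=image_field,
        control_model="lllyasviel/sd-controlnet-canny",
        control_weight=0.5,
    ),
    # same image exposed directly, so nodes that only want the processed
    # image can take it without unpacking the ControlField
    raw_processed_image=image_field,
)
```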