Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 1062fc4796
Initial support for polymorphic field types. A polymorphic type is either a single value or a list of a specific type, for example `Union[str, list[str]]`. Polymorphics do not yet support direct input in the UI (that will come in the future); they are forcibly set as connection-only fields, so users cannot provide direct input to them. If a polymorphic should present as a singleton type, which would allow direct input, the node must provide an explicit type hint. For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the node editor we want to present it as a number input, so in the node definition the field is given `ui_type=UIType.Float`, which tells the UI to treat it as a `float` field. In this situation the connection validation logic prevents connecting a collection to `CFG Scale`, because it is typed as `float`; the workaround is to disable validation in the settings to make that specific connection. A future improvement will resolve this. A sketch of such a field declaration is shown after the list below.

This also introduces better support for collection field types. Like polymorphics, collection types are parsed automatically by the client and do not need any specific type hints. Also like polymorphics, there is no support yet for direct input of collection types in the UI.

- Disabling validation in the workflow editor now displays the visual hints for valid connections, but lets you connect to anything.
- Added `ui_order: int` to `InputField` and `OutputField`. The UI will use this, if present, to order fields in a node UI. See usage in `DenoiseLatents` for an example.
- Updated the field colors; duplicate colors have just been lightened a bit. It's not perfect, but it was a quick fix.
- Field handles for collections are the same color as their single counterparts, but have a dark dot in the center.
- Field handles for polymorphics are a rounded square with a dot in the middle.
- Removed all fields that just render `null` from `InputFieldRenderer`, replaced with a single fallback.
- Removed the logic in `zValidatedWorkflow` that checked for the existence of node templates for each node in a workflow. That logic introduced a circular dependency, because it imported the global redux `store` to get the node templates within a zod schema. It's fine to leave this out entirely; the case of a missing node template is handled by the UI, and fixing it otherwise would introduce a substantial headache.
- Fixed the `ControlNetInvocation.control_model` field default, which was a string when it shouldn't have had a default at all.
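As an illustration of the pattern described above, here is a minimal sketch of a polymorphic field given an explicit UI type hint. It mirrors `ControlNetInvocation.control_weight` in the file below; the `ExampleInvocation` node and its `weights` field are hypothetical names used only for this example.

```python
from typing import List, Union

from .baseinvocation import BaseInvocation, InputField, UIType, invocation


@invocation("example", title="Example", tags=["example"], category="example")
class ExampleInvocation(BaseInvocation):
    """Hypothetical node with a polymorphic field (a single float OR a list of floats)."""

    # Polymorphic: accepts either a float or a list of floats. Without a hint the UI
    # would treat this as connection-only; ui_type=UIType.Float presents it as a
    # direct number input instead.
    weights: Union[float, List[float]] = InputField(
        default=1.0, description="Example polymorphic field", ui_type=UIType.Float
    )
```

With the `ui_type` hint, the UI renders a number input, and connection validation then treats the field strictly as `float`, which is exactly the `CFG Scale` caveat described above.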
546 lines
21 KiB
Python
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Optional, Union

import cv2
import numpy as np
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LeresDetector,
    LineartAnimeDetector,
    LineartDetector,
    MediapipeFaceDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
    SamDetector,
    ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator

from invokeai.app.invocations.primitives import ImageField, ImageOutput

from ...backend.model_management import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    InputField,
    Input,
    InvocationContext,
    OutputField,
    UIType,
    invocation,
    invocation_output,
)


CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
    "just_resize",
    "crop_resize",
    "fill_resize",
    "just_resize_simple",
]


class ControlNetModelField(BaseModel):
    """ControlNet model field"""

    model_name: str = Field(description="Name of the ControlNet model")
    base_model: BaseModelType = Field(description="Base model")


class ControlField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @validator("control_weight")
    def validate_control_weight(cls, v):
        """Validate that all control weights are in the valid range"""
        if isinstance(v, list):
            for i in v:
                if i < -1 or i > 2:
                    raise ValueError("Control weights must be within -1 to 2 range")
        else:
            if v < -1 or v > 2:
                raise ValueError("Control weights must be within -1 to 2 range")
        return v


@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Outputs
    control: ControlField = OutputField(description=FieldDescriptions.control)


@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    image: ImageField = InputField(description="The control image")
    control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    def invoke(self, context: InvocationContext) -> ControlOutput:
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )


@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
class ImageProcessorInvocation(BaseInvocation):
    """Base class for invocations that preprocess images for ControlNet"""

    image: ImageField = InputField(description="The image to process")

    def run_processor(self, image):
        # superclass just passes through image without processing
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        raw_image = context.services.images.get_pil_image(self.image.image_name)
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # currently can't see processed image in node UI without a showImage node,
        # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
        image_dto = context.services.images.create(
            image=processed_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.CONTROL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate,
            workflow=self.workflow,
        )

        # builds an ImageOutput and its ImageField
        processed_image_field = ImageField(image_name=image_dto.image_name)
        return ImageOutput(
            image=processed_image_field,
            # width=processed_image.width,
            width=image_dto.width,
            # height=processed_image.height,
            height=image_dto.height,
            # mode=processed_image.mode,
        )


@invocation(
    "canny_image_processor",
    title="Canny Processor",
    tags=["controlnet", "canny"],
    category="controlnet",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
    """Canny edge detection for ControlNet"""

    low_threshold: int = InputField(
        default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
    )
    high_threshold: int = InputField(
        default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
    )

    def run_processor(self, image):
        canny_processor = CannyDetector()
        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
        return processed_image


@invocation(
    "hed_image_processor",
    title="HED (softedge) Processor",
    tags=["controlnet", "hed", "softedge"],
    category="controlnet",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
    """Applies HED edge detection to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            # safe not supported in controlnet_aux v0.0.3
            # safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "lineart_image_processor",
    title="Lineart Processor",
    tags=["controlnet", "lineart"],
    category="controlnet",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    coarse: bool = InputField(default=False, description="Whether to use coarse mode")

    def run_processor(self, image):
        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = lineart_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
        )
        return processed_image


@invocation(
    "lineart_anime_image_processor",
    title="Lineart Anime Processor",
    tags=["controlnet", "lineart", "anime"],
    category="controlnet",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art anime processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "openpose_image_processor",
    title="Openpose Processor",
    tags=["controlnet", "openpose", "pose"],
    category="controlnet",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Openpose processing to image"""

    hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            hand_and_face=self.hand_and_face,
        )
        return processed_image


@invocation(
    "midas_depth_image_processor",
    title="Midas Depth Processor",
    tags=["controlnet", "midas"],
    category="controlnet",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Midas depth processing to image"""

    a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(
            image,
            a=np.pi * self.a_mult,
            bg_th=self.bg_th,
            # depth_and_normal not supported in controlnet_aux v0.0.3
            # depth_and_normal=self.depth_and_normal,
        )
        return processed_image


@invocation(
    "normalbae_image_processor",
    title="Normal BAE Processor",
    tags=["controlnet"],
    category="controlnet",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies NormalBae processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
        )
        return processed_image


@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
    """Applies MLSD processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            thr_v=self.thr_v,
            thr_d=self.thr_d,
        )
        return processed_image


@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
    """Applies PIDI processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "content_shuffle_image_processor",
    title="Content Shuffle Processor",
    tags=["controlnet", "contentshuffle"],
    category="controlnet",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
    """Applies content shuffle processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
    w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
    f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            h=self.h,
            w=self.w,
            f=self.f,
        )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@invocation(
    "zoe_depth_image_processor",
    title="Zoe (Depth) Processor",
    tags=["controlnet", "zoe", "depth"],
    category="controlnet",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Zoe depth processing to image"""

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


@invocation(
    "mediapipe_face_processor",
    title="Mediapipe Face Processor",
    tags=["controlnet", "mediapipe", "face"],
    category="controlnet",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
    """Applies mediapipe face processing to image"""

    max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")

    def run_processor(self, image):
        # MediaPipeFaceDetector throws an error if image has alpha channel
        # so convert to RGB if needed
        if image.mode == "RGBA":
            image = image.convert("RGB")
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image


@invocation(
    "leres_image_processor",
    title="Leres (Depth) Processor",
    tags=["controlnet", "leres", "depth"],
    category="controlnet",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
    """Applies leres processing to image"""

    thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
    thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
    boost: bool = InputField(default=False, description="Whether to use boost mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "tile_image_processor",
    title="Tile Resample Processor",
    tags=["controlnet", "tile"],
    category="controlnet",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
    """Tile resampler processor"""

    # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # never used?
        down_sampling_rate=1.0,
    ):
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        H, W, C = np_img.shape
        H = int(float(H) / float(down_sampling_rate))
        W = int(float(W) / float(down_sampling_rate))
        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
        return np_img

    def run_processor(self, img):
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate,
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


@invocation(
    "segment_anything_processor",
    title="Segment Anything Processor",
    tags=["controlnet", "segmentanything"],
    category="controlnet",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    def run_processor(self, image):
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        # SamDetectorReproducibleColors is defined below in this module; it is available by the time run_processor is called
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):
    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    # base class show_anns() method randomizes colors,
    # which seems to also lead to non-reproducible image generation
    # so using ADE20k color palette instead
    def show_anns(self, anns: List[Dict]):
        if len(anns) == 0:
            return
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
        h, w = anns[0]["segmentation"].shape
        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann["segmentation"]
            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)