From de4064bdac1d80ec3aac496b3727854094a42d8f Mon Sep 17 00:00:00 2001 From: user1 Date: Sun, 25 Jun 2023 12:38:17 -0700 Subject: [PATCH] Fixed problem with with non-reproducible results from ControlNet SegmentAnything preprocessor. Cause was controlnet_aux randomization of segmentation coloring, which seems to lead to some randomization of resulting images using ControlNet seg model. Switched to using deterministic ADE20K color palette instead, which solved the problem. --- .../controlnet_image_processors.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 0fdfc12905..c5777284a5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -4,7 +4,7 @@ from builtins import float, bool import numpy as np -from typing import Literal, Optional, Union, List +from typing import Literal, Optional, Union, List, Dict from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field, validator @@ -32,6 +32,9 @@ from controlnet_aux import ( SamDetector, ) +from controlnet_aux.util import ade_palette + + from .image import ImageOutput, PILInvocationConfig CONTROLNET_DEFAULT_MODELS = [ @@ -465,6 +468,35 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation # fmt: on def run_processor(self, image): - segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") + # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") + segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") processed_image = segment_anything_processor(image) return processed_image + +class SamDetectorReproducibleColors(SamDetector): + + # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image + # base class show_anns() method randomizes colors, + # which seems to also lead to non-reproducible image generation + # so using ADE20k color palette instead + def show_anns(self, anns: List[Dict]): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + h, w = anns[0]['segmentation'].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + print("number of annotations: ", len(sorted_anns)) + print("type of annotations: ", type(sorted_anns)) + palette = ade_palette() + for i, ann in enumerate(sorted_anns): + m = ann['segmentation'] + img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) + # doing modulo just in case number of annotated regions exceeds number of colors in palette + ann_color = palette[i % len(palette)] + print(ann_color) + img[:, :, 0] = ann_color[0] + img[:, :, 1] = ann_color[1] + img[:, :, 2] = ann_color[2] + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255))) + + return np.array(final_img, dtype=np.uint8)