Fixed problem with with non-reproducible results from ControlNet SegmentAnything preprocessor. Cause was controlnet_aux randomization of segmentation coloring, which seems to lead to some randomization of resulting images using ControlNet seg model. Switched to using deterministic ADE20K color palette instead, which solved the problem.

This commit is contained in:
user1 2023-06-25 12:38:17 -07:00
parent 10c3753d7f
commit de4064bdac

View File

@ -4,7 +4,7 @@
from builtins import float, bool
import numpy as np
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator
@ -32,6 +32,9 @@ from controlnet_aux import (
SamDetector,
)
from controlnet_aux.util import ade_palette
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
@ -465,6 +468,35 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation
# fmt: on
def run_processor(self, image):
segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
processed_image = segment_anything_processor(image)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
print("number of annotations: ", len(sorted_anns))
print("type of annotations: ", type(sorted_anns))
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
print(ann_color)
img[:, :, 0] = ann_color[0]
img[:, :, 1] = ann_color[1]
img[:, :, 2] = ann_color[2]
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)