mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixed problem with with non-reproducible results from ControlNet SegmentAnything preprocessor. Cause was controlnet_aux randomization of segmentation coloring, which seems to lead to some randomization of resulting images using ControlNet seg model. Switched to using deterministic ADE20K color palette instead, which solved the problem.
This commit is contained in:
parent
10c3753d7f
commit
de4064bdac
@ -4,7 +4,7 @@
|
||||
from builtins import float, bool
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from typing import Literal, Optional, Union, List, Dict
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
@ -32,6 +32,9 @@ from controlnet_aux import (
|
||||
SamDetector,
|
||||
)
|
||||
|
||||
from controlnet_aux.util import ade_palette
|
||||
|
||||
|
||||
from .image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
@ -465,6 +468,35 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
processed_image = segment_anything_processor(image)
|
||||
return processed_image
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
# so using ADE20k color palette instead
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||
h, w = anns[0]['segmentation'].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
print("number of annotations: ", len(sorted_anns))
|
||||
print("type of annotations: ", type(sorted_anns))
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann['segmentation']
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
print(ann_color)
|
||||
img[:, :, 0] = ann_color[0]
|
||||
img[:, :, 1] = ann_color[1]
|
||||
img[:, :, 2] = ann_color[2]
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
Loading…
Reference in New Issue
Block a user