InvokeAI/invokeai/backend/vto_workflow/extract_channel.py
Ryan Dick 36d72baaaa WIP
2024-07-25 14:14:12 -04:00

58 lines
1.6 KiB
Python

from enum import Enum
import cv2
import numpy as np
import numpy.typing as npt
class ImageChannel(Enum):
    """Names of the single-channel planes that can be pulled out of an image.

    Each member is named ``<COLOR SPACE>_<CHANNEL>`` and its value is the same
    string as its name, so a member can be looked up from its string form with
    ``ImageChannel("LAB_A")``.
    """

    # Planes of the RGB representation.
    RGB_R = "RGB_R"
    RGB_G = "RGB_G"
    RGB_B = "RGB_B"

    # Planes of the LAB representation.
    LAB_L = "LAB_L"
    LAB_A = "LAB_A"
    LAB_B = "LAB_B"

    # Planes of the HSV representation.
    HSV_H = "HSV_H"
    HSV_S = "HSV_S"
    HSV_V = "HSV_V"
def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
    """Return one plane of an RGB image, optionally after an LAB or HSV conversion.

    RGB planes are sliced straight out of ``image`` (no conversion is run); LAB
    and HSV planes are sliced out of the array returned by ``cv2.cvtColor``.

    NOTE(review): for uint8 input, OpenCV documents its own 8-bit scaling of
    LAB and HSV values (e.g. hue stored as H/2) -- confirm callers expect that
    scaling rather than the textbook ranges.

    Args:
        image (np.ndarray): Shape (H, W, 3) of dtype uint8, in RGB channel order.
        channel (ImageChannel): The plane to return.

    Returns:
        np.ndarray: Shape (H, W) of dtype uint8.

    Raises:
        ValueError: If ``channel`` does not match any ``ImageChannel`` member.
    """
    # Position within each tuple is the plane index inside that colour space.
    rgb_planes = (ImageChannel.RGB_R, ImageChannel.RGB_G, ImageChannel.RGB_B)
    lab_planes = (ImageChannel.LAB_L, ImageChannel.LAB_A, ImageChannel.LAB_B)
    hsv_planes = (ImageChannel.HSV_H, ImageChannel.HSV_S, ImageChannel.HSV_V)

    if channel in rgb_planes:
        return image[:, :, rgb_planes.index(channel)]

    if channel in lab_planes:
        converted = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        return converted[:, :, lab_planes.index(channel)]

    if channel in hsv_planes:
        converted = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        return converted[:, :, hsv_planes.index(channel)]

    raise ValueError(f"Unknown channel: {channel}")