mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
10ada84404
This reverts commit f50f95a81d
.
693 lines
26 KiB
Python
693 lines
26 KiB
Python
import math
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional, TypedDict
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
|
from PIL.Image import Image as ImageType
|
|
from pydantic import validator
|
|
|
|
import invokeai.assets.fonts as font_assets
|
|
from invokeai.app.invocations.baseinvocation import (
|
|
BaseInvocation,
|
|
InputField,
|
|
InvocationContext,
|
|
OutputField,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
|
|
|
|
|
@invocation_output("face_mask_output")
|
|
class FaceMaskOutput(ImageOutput):
|
|
"""Base class for FaceMask output"""
|
|
|
|
mask: ImageField = OutputField(description="The output mask")
|
|
|
|
|
|
@invocation_output("face_off_output")
|
|
class FaceOffOutput(ImageOutput):
|
|
"""Base class for FaceOff Output"""
|
|
|
|
mask: ImageField = OutputField(description="The output mask")
|
|
x: int = OutputField(description="The x coordinate of the bounding box's left side")
|
|
y: int = OutputField(description="The y coordinate of the bounding box's top side")
|
|
|
|
|
|
class FaceResultData(TypedDict):
|
|
image: ImageType
|
|
mask: ImageType
|
|
x_center: float
|
|
y_center: float
|
|
mesh_width: int
|
|
mesh_height: int
|
|
|
|
|
|
class FaceResultDataWithId(FaceResultData):
|
|
face_id: int
|
|
|
|
|
|
class ExtractFaceData(TypedDict):
|
|
bounded_image: ImageType
|
|
bounded_mask: ImageType
|
|
x_min: int
|
|
y_min: int
|
|
x_max: int
|
|
y_max: int
|
|
|
|
|
|
class FaceMaskResult(TypedDict):
|
|
image: ImageType
|
|
mask: ImageType
|
|
|
|
|
|
def create_white_image(w: int, h: int) -> ImageType:
|
|
return Image.new("L", (w, h), color=255)
|
|
|
|
|
|
def create_black_image(w: int, h: int) -> ImageType:
|
|
return Image.new("L", (w, h), color=0)
|
|
|
|
|
|
FONT_SIZE = 32
|
|
FONT_STROKE_WIDTH = 4
|
|
|
|
|
|
def prepare_faces_list(
|
|
face_result_list: list[FaceResultData],
|
|
) -> list[FaceResultDataWithId]:
|
|
"""Deduplicates a list of faces, adding IDs to them."""
|
|
deduped_faces: list[FaceResultData] = []
|
|
|
|
if len(face_result_list) == 0:
|
|
return list()
|
|
|
|
for candidate in face_result_list:
|
|
should_add = True
|
|
candidate_x_center = candidate["x_center"]
|
|
candidate_y_center = candidate["y_center"]
|
|
for face in deduped_faces:
|
|
face_center_x = face["x_center"]
|
|
face_center_y = face["y_center"]
|
|
face_radius_w = face["mesh_width"] / 2
|
|
face_radius_h = face["mesh_height"] / 2
|
|
# Determine if the center of the candidate_face is inside the ellipse of the added face
|
|
# p < 1 -> Inside
|
|
# p = 1 -> Exactly on the ellipse
|
|
# p > 1 -> Outside
|
|
p = (math.pow((candidate_x_center - face_center_x), 2) / math.pow(face_radius_w, 2)) + (
|
|
math.pow((candidate_y_center - face_center_y), 2) / math.pow(face_radius_h, 2)
|
|
)
|
|
|
|
if p < 1: # Inside of the already-added face's radius
|
|
should_add = False
|
|
break
|
|
|
|
if should_add is True:
|
|
deduped_faces.append(candidate)
|
|
|
|
sorted_faces = sorted(deduped_faces, key=lambda x: x["y_center"])
|
|
sorted_faces = sorted(sorted_faces, key=lambda x: x["x_center"])
|
|
|
|
# add face_id for reference
|
|
sorted_faces_with_ids: list[FaceResultDataWithId] = []
|
|
face_id_counter = 0
|
|
for face in sorted_faces:
|
|
sorted_faces_with_ids.append(
|
|
FaceResultDataWithId(
|
|
**face,
|
|
face_id=face_id_counter,
|
|
)
|
|
)
|
|
face_id_counter += 1
|
|
|
|
return sorted_faces_with_ids
|
|
|
|
|
|
def generate_face_box_mask(
|
|
context: InvocationContext,
|
|
minimum_confidence: float,
|
|
x_offset: float,
|
|
y_offset: float,
|
|
pil_image: ImageType,
|
|
chunk_x_offset: int = 0,
|
|
chunk_y_offset: int = 0,
|
|
draw_mesh: bool = True,
|
|
check_bounds: bool = True,
|
|
) -> list[FaceResultData]:
|
|
result = []
|
|
mask_pil = None
|
|
|
|
# Convert the PIL image to a NumPy array.
|
|
np_image = np.array(pil_image, dtype=np.uint8)
|
|
|
|
# Check if the input image has four channels (RGBA).
|
|
if np_image.shape[2] == 4:
|
|
# Convert RGBA to RGB by removing the alpha channel.
|
|
np_image = np_image[:, :, :3]
|
|
|
|
# Create a FaceMesh object for face landmark detection and mesh generation.
|
|
face_mesh = FaceMesh(
|
|
max_num_faces=999,
|
|
min_detection_confidence=minimum_confidence,
|
|
min_tracking_confidence=minimum_confidence,
|
|
)
|
|
|
|
# Detect the face landmarks and mesh in the input image.
|
|
results = face_mesh.process(np_image)
|
|
|
|
# Check if any face is detected.
|
|
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
|
|
# Search for the face_id in the detected faces.
|
|
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
|
# Get the bounding box of the face mesh.
|
|
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
|
|
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
|
|
x_min, x_max = min(x_coordinates), max(x_coordinates)
|
|
y_min, y_max = min(y_coordinates), max(y_coordinates)
|
|
|
|
# Calculate the width and height of the face mesh.
|
|
mesh_width = int((x_max - x_min) * np_image.shape[1])
|
|
mesh_height = int((y_max - y_min) * np_image.shape[0])
|
|
|
|
# Get the center of the face.
|
|
x_center = np.mean([landmark.x * np_image.shape[1] for landmark in face_landmarks.landmark])
|
|
y_center = np.mean([landmark.y * np_image.shape[0] for landmark in face_landmarks.landmark])
|
|
|
|
face_landmark_points = np.array(
|
|
[
|
|
[landmark.x * np_image.shape[1], landmark.y * np_image.shape[0]]
|
|
for landmark in face_landmarks.landmark
|
|
]
|
|
)
|
|
|
|
# Apply the scaling offsets to the face landmark points with a multiplier.
|
|
scale_multiplier = 0.2
|
|
x_center = np.mean(face_landmark_points[:, 0])
|
|
y_center = np.mean(face_landmark_points[:, 1])
|
|
|
|
if draw_mesh:
|
|
x_scaled = face_landmark_points[:, 0] + scale_multiplier * x_offset * (
|
|
face_landmark_points[:, 0] - x_center
|
|
)
|
|
y_scaled = face_landmark_points[:, 1] + scale_multiplier * y_offset * (
|
|
face_landmark_points[:, 1] - y_center
|
|
)
|
|
|
|
convex_hull = cv2.convexHull(np.column_stack((x_scaled, y_scaled)).astype(np.int32))
|
|
|
|
# Generate a binary face mask using the face mesh.
|
|
mask_image = np.ones(np_image.shape[:2], dtype=np.uint8) * 255
|
|
cv2.fillConvexPoly(mask_image, convex_hull, 0)
|
|
|
|
# Convert the binary mask image to a PIL Image.
|
|
init_mask_pil = Image.fromarray(mask_image, mode="L")
|
|
w, h = init_mask_pil.size
|
|
mask_pil = create_white_image(w + chunk_x_offset, h + chunk_y_offset)
|
|
mask_pil.paste(init_mask_pil, (chunk_x_offset, chunk_y_offset))
|
|
|
|
left_side = x_center - mesh_width
|
|
right_side = x_center + mesh_width
|
|
top_side = y_center - mesh_height
|
|
bottom_side = y_center + mesh_height
|
|
im_width, im_height = pil_image.size
|
|
over_w = im_width * 0.1
|
|
over_h = im_height * 0.1
|
|
if not check_bounds or (
|
|
(left_side >= -over_w)
|
|
and (right_side < im_width + over_w)
|
|
and (top_side >= -over_h)
|
|
and (bottom_side < im_height + over_h)
|
|
):
|
|
x_center = float(x_center)
|
|
y_center = float(y_center)
|
|
face = FaceResultData(
|
|
image=pil_image,
|
|
mask=mask_pil or create_white_image(*pil_image.size),
|
|
x_center=x_center + chunk_x_offset,
|
|
y_center=y_center + chunk_y_offset,
|
|
mesh_width=mesh_width,
|
|
mesh_height=mesh_height,
|
|
)
|
|
|
|
result.append(face)
|
|
else:
|
|
context.services.logger.info("FaceTools --> Face out of bounds, ignoring.")
|
|
|
|
return result
|
|
|
|
|
|
def extract_face(
|
|
context: InvocationContext,
|
|
image: ImageType,
|
|
face: FaceResultData,
|
|
padding: int,
|
|
) -> ExtractFaceData:
|
|
mask = face["mask"]
|
|
center_x = face["x_center"]
|
|
center_y = face["y_center"]
|
|
mesh_width = face["mesh_width"]
|
|
mesh_height = face["mesh_height"]
|
|
|
|
# Determine the minimum size of the square crop
|
|
min_size = min(mask.width, mask.height)
|
|
|
|
# Calculate the crop boundaries for the output image and mask.
|
|
mesh_width += 128 + padding # add pixels to account for mask variance
|
|
mesh_height += 128 + padding # add pixels to account for mask variance
|
|
crop_size = min(
|
|
max(mesh_width, mesh_height, 128), min_size
|
|
) # Choose the smaller of the two (given value or face mask size)
|
|
if crop_size > 128:
|
|
crop_size = (crop_size + 7) // 8 * 8 # Ensure crop side is multiple of 8
|
|
|
|
# Calculate the actual crop boundaries within the bounds of the original image.
|
|
x_min = int(center_x - crop_size / 2)
|
|
y_min = int(center_y - crop_size / 2)
|
|
x_max = int(center_x + crop_size / 2)
|
|
y_max = int(center_y + crop_size / 2)
|
|
|
|
# Adjust the crop boundaries to stay within the original image's dimensions
|
|
if x_min < 0:
|
|
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
|
x_max -= x_min
|
|
x_min = 0
|
|
elif x_max > mask.width:
|
|
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
|
x_min -= x_max - mask.width
|
|
x_max = mask.width
|
|
|
|
if y_min < 0:
|
|
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
|
y_max -= y_min
|
|
y_min = 0
|
|
elif y_max > mask.height:
|
|
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
|
y_min -= y_max - mask.height
|
|
y_max = mask.height
|
|
|
|
# Ensure the crop is square and adjust the boundaries if needed
|
|
if x_max - x_min != crop_size:
|
|
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
|
diff = crop_size - (x_max - x_min)
|
|
x_min -= diff // 2
|
|
x_max += diff - diff // 2
|
|
|
|
if y_max - y_min != crop_size:
|
|
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
|
diff = crop_size - (y_max - y_min)
|
|
y_min -= diff // 2
|
|
y_max += diff - diff // 2
|
|
|
|
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
|
|
|
# Crop the output image to the specified size with the center of the face mesh as the center.
|
|
mask = mask.crop((x_min, y_min, x_max, y_max))
|
|
bounded_image = image.crop((x_min, y_min, x_max, y_max))
|
|
|
|
# blur mask edge by small radius
|
|
mask = mask.filter(ImageFilter.GaussianBlur(radius=2))
|
|
|
|
return ExtractFaceData(
|
|
bounded_image=bounded_image,
|
|
bounded_mask=mask,
|
|
x_min=x_min,
|
|
y_min=y_min,
|
|
x_max=x_max,
|
|
y_max=y_max,
|
|
)
|
|
|
|
|
|
def get_faces_list(
|
|
context: InvocationContext,
|
|
image: ImageType,
|
|
should_chunk: bool,
|
|
minimum_confidence: float,
|
|
x_offset: float,
|
|
y_offset: float,
|
|
draw_mesh: bool = True,
|
|
) -> list[FaceResultDataWithId]:
|
|
result = []
|
|
|
|
# Generate the face box mask and get the center of the face.
|
|
if not should_chunk:
|
|
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
|
result = generate_face_box_mask(
|
|
context=context,
|
|
minimum_confidence=minimum_confidence,
|
|
x_offset=x_offset,
|
|
y_offset=y_offset,
|
|
pil_image=image,
|
|
chunk_x_offset=0,
|
|
chunk_y_offset=0,
|
|
draw_mesh=draw_mesh,
|
|
check_bounds=False,
|
|
)
|
|
if should_chunk or len(result) == 0:
|
|
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
|
width, height = image.size
|
|
image_chunks = []
|
|
x_offsets = []
|
|
y_offsets = []
|
|
result = []
|
|
|
|
# If width == height, there's nothing more we can do... otherwise...
|
|
if width > height:
|
|
# Landscape - slice the image horizontally
|
|
fx = 0.0
|
|
steps = int(width * 2 / height)
|
|
while fx <= (width - height):
|
|
x = int(fx)
|
|
image_chunks.append(image.crop((x, 0, x + height - 1, height - 1)))
|
|
x_offsets.append(x)
|
|
y_offsets.append(0)
|
|
fx += (width - height) / steps
|
|
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
|
elif height > width:
|
|
# Portrait - slice the image vertically
|
|
fy = 0.0
|
|
steps = int(height * 2 / width)
|
|
while fy <= (height - width):
|
|
y = int(fy)
|
|
image_chunks.append(image.crop((0, y, width - 1, y + width - 1)))
|
|
x_offsets.append(0)
|
|
y_offsets.append(y)
|
|
fy += (height - width) / steps
|
|
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
|
|
|
for idx in range(len(image_chunks)):
|
|
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
|
result = result + generate_face_box_mask(
|
|
context=context,
|
|
minimum_confidence=minimum_confidence,
|
|
x_offset=x_offset,
|
|
y_offset=y_offset,
|
|
pil_image=image_chunks[idx],
|
|
chunk_x_offset=x_offsets[idx],
|
|
chunk_y_offset=y_offsets[idx],
|
|
draw_mesh=draw_mesh,
|
|
)
|
|
|
|
if len(result) == 0:
|
|
# Give up
|
|
context.services.logger.warning(
|
|
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
|
)
|
|
|
|
all_faces = prepare_faces_list(result)
|
|
|
|
return all_faces
|
|
|
|
|
|
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
|
|
class FaceOffInvocation(BaseInvocation):
|
|
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
|
|
|
image: ImageField = InputField(description="Image for face detection")
|
|
face_id: int = InputField(
|
|
default=0,
|
|
ge=0,
|
|
description="The face ID to process, numbered from 0. Multiple faces not supported. Find a face's ID with FaceIdentifier node.",
|
|
)
|
|
minimum_confidence: float = InputField(
|
|
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
|
)
|
|
x_offset: float = InputField(default=0.0, description="X-axis offset of the mask")
|
|
y_offset: float = InputField(default=0.0, description="Y-axis offset of the mask")
|
|
padding: int = InputField(default=0, description="All-axis padding around the mask in pixels")
|
|
chunk: bool = InputField(
|
|
default=False,
|
|
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
|
)
|
|
|
|
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
|
|
all_faces = get_faces_list(
|
|
context=context,
|
|
image=image,
|
|
should_chunk=self.chunk,
|
|
minimum_confidence=self.minimum_confidence,
|
|
x_offset=self.x_offset,
|
|
y_offset=self.y_offset,
|
|
draw_mesh=True,
|
|
)
|
|
|
|
if len(all_faces) == 0:
|
|
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
|
return None
|
|
|
|
if self.face_id > len(all_faces) - 1:
|
|
context.services.logger.warning(
|
|
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
|
)
|
|
return None
|
|
|
|
face_data = extract_face(context=context, image=image, face=all_faces[self.face_id], padding=self.padding)
|
|
# Convert the input image to RGBA mode to ensure it has an alpha channel.
|
|
face_data["bounded_image"] = face_data["bounded_image"].convert("RGBA")
|
|
|
|
return face_data
|
|
|
|
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
result = self.faceoff(context=context, image=image)
|
|
|
|
if result is None:
|
|
result_image = image
|
|
result_mask = create_white_image(*image.size)
|
|
x = 0
|
|
y = 0
|
|
else:
|
|
result_image = result["bounded_image"]
|
|
result_mask = result["bounded_mask"]
|
|
x = result["x_min"]
|
|
y = result["y_min"]
|
|
|
|
image_dto = context.services.images.create(
|
|
image=result_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
workflow=self.workflow,
|
|
)
|
|
|
|
mask_dto = context.services.images.create(
|
|
image=result_mask,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.MASK,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
output = FaceOffOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
mask=ImageField(image_name=mask_dto.image_name),
|
|
x=x,
|
|
y=y,
|
|
)
|
|
|
|
return output
|
|
|
|
|
|
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
|
|
class FaceMaskInvocation(BaseInvocation):
|
|
"""Face mask creation using mediapipe face detection"""
|
|
|
|
image: ImageField = InputField(description="Image to face detect")
|
|
face_ids: str = InputField(
|
|
default="",
|
|
description="Comma-separated list of face ids to mask eg '0,2,7'. Numbered from 0. Leave empty to mask all. Find face IDs with FaceIdentifier node.",
|
|
)
|
|
minimum_confidence: float = InputField(
|
|
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
|
)
|
|
x_offset: float = InputField(default=0.0, description="Offset for the X-axis of the face mask")
|
|
y_offset: float = InputField(default=0.0, description="Offset for the Y-axis of the face mask")
|
|
chunk: bool = InputField(
|
|
default=False,
|
|
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
|
)
|
|
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
|
|
|
@validator("face_ids")
|
|
def validate_comma_separated_ints(cls, v) -> str:
|
|
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
|
if comma_separated_ints_regex.match(v) is None:
|
|
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
|
|
return v
|
|
|
|
def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
|
|
all_faces = get_faces_list(
|
|
context=context,
|
|
image=image,
|
|
should_chunk=self.chunk,
|
|
minimum_confidence=self.minimum_confidence,
|
|
x_offset=self.x_offset,
|
|
y_offset=self.y_offset,
|
|
draw_mesh=True,
|
|
)
|
|
|
|
mask_pil = create_white_image(*image.size)
|
|
|
|
id_range = list(range(0, len(all_faces)))
|
|
ids_to_extract = id_range
|
|
if self.face_ids != "":
|
|
parsed_face_ids = [int(id) for id in self.face_ids.split(",")]
|
|
# get requested face_ids that are in range
|
|
intersected_face_ids = set(parsed_face_ids) & set(id_range)
|
|
|
|
if len(intersected_face_ids) == 0:
|
|
id_range_str = ",".join([str(id) for id in id_range])
|
|
context.services.logger.warning(
|
|
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
|
)
|
|
return FaceMaskResult(
|
|
image=image, # original image
|
|
mask=mask_pil, # white mask
|
|
)
|
|
|
|
ids_to_extract = list(intersected_face_ids)
|
|
|
|
for face_id in ids_to_extract:
|
|
face_data = extract_face(context=context, image=image, face=all_faces[face_id], padding=0)
|
|
face_mask_pil = face_data["bounded_mask"]
|
|
x_min = face_data["x_min"]
|
|
y_min = face_data["y_min"]
|
|
x_max = face_data["x_max"]
|
|
y_max = face_data["y_max"]
|
|
|
|
mask_pil.paste(
|
|
create_black_image(x_max - x_min, y_max - y_min),
|
|
box=(x_min, y_min),
|
|
mask=ImageOps.invert(face_mask_pil),
|
|
)
|
|
|
|
if self.invert_mask:
|
|
mask_pil = ImageOps.invert(mask_pil)
|
|
|
|
# Create an RGBA image with transparency
|
|
image = image.convert("RGBA")
|
|
|
|
return FaceMaskResult(
|
|
image=image,
|
|
mask=mask_pil,
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
result = self.facemask(context=context, image=image)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=result["image"],
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
workflow=self.workflow,
|
|
)
|
|
|
|
mask_dto = context.services.images.create(
|
|
image=result["mask"],
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.MASK,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
output = FaceMaskOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
mask=ImageField(image_name=mask_dto.image_name),
|
|
)
|
|
|
|
return output
|
|
|
|
|
|
@invocation(
|
|
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
|
|
)
|
|
class FaceIdentifierInvocation(BaseInvocation):
|
|
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
|
|
|
image: ImageField = InputField(description="Image to face detect")
|
|
minimum_confidence: float = InputField(
|
|
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
|
)
|
|
chunk: bool = InputField(
|
|
default=False,
|
|
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
|
)
|
|
|
|
def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
|
|
image = image.copy()
|
|
|
|
all_faces = get_faces_list(
|
|
context=context,
|
|
image=image,
|
|
should_chunk=self.chunk,
|
|
minimum_confidence=self.minimum_confidence,
|
|
x_offset=0,
|
|
y_offset=0,
|
|
draw_mesh=False,
|
|
)
|
|
|
|
# Note - font may be found either in the repo if running an editable install, or in the venv if running a package install
|
|
font_path = [x for x in [Path(y, "inter/Inter-Regular.ttf") for y in font_assets.__path__] if x.exists()]
|
|
font = ImageFont.truetype(font_path[0].as_posix(), FONT_SIZE)
|
|
|
|
# Paste face IDs on the output image
|
|
draw = ImageDraw.Draw(image)
|
|
for face in all_faces:
|
|
x_coord = face["x_center"]
|
|
y_coord = face["y_center"]
|
|
text = str(face["face_id"])
|
|
# get bbox of the text so we can center the id on the face
|
|
_, _, bbox_w, bbox_h = draw.textbbox(xy=(0, 0), text=text, font=font, stroke_width=FONT_STROKE_WIDTH)
|
|
x = x_coord - bbox_w / 2
|
|
y = y_coord - bbox_h / 2
|
|
draw.text(
|
|
xy=(x, y),
|
|
text=str(text),
|
|
fill=(255, 255, 255, 255),
|
|
font=font,
|
|
stroke_width=FONT_STROKE_WIDTH,
|
|
stroke_fill=(0, 0, 0, 255),
|
|
)
|
|
|
|
# Create an RGBA image with transparency
|
|
image = image.convert("RGBA")
|
|
|
|
return image
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
result_image = self.faceidentifier(context=context, image=image)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=result_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
workflow=self.workflow,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|