mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
c42d692ea6
* chore: bump pydantic to 2.5.2
  This release fixes pydantic/pydantic#8175 and allows us to use `JsonValue`
* fix(ui): exclude public/en.json from prettier config
* fix(workflow_records): fix SQLite workflow insertion to ignore duplicates
* feat(backend): update workflows handling
  Update workflows handling for Workflow Library.
  **Updated Workflow Storage**
  "Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in DB.
  This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to in image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost.
  **Updated Workflow Handling in Nodes**
  Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically.
  A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`.
  **Database Migrations**
  Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accommodate the DB-related changes in this PR. See the code for details.
  The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A follow-up PR will adapt them to use the migrator.
  **Other/Support Changes**
  - Add a `has_workflow` column to `images` table to indicate that the image has an embedded workflow.
  - Add handling for retrieving the workflow from an image in Python. The image file must be fetched, the workflow extracted, and then sent to client, avoiding needing the browser to parse the image file. With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow.
  - Add route to get the workflow from an image
  - Add CRUD service/routes for the library workflows
  - `workflow_images` table and services removed (no longer needed now that embedded workflows are not in the DB)
* feat(ui): updated workflow handling (WIP)
  Client-side updates for the backend workflow changes. Includes roughed-out workflow library UI.
* feat: revert SQLiteMigrator class
  Will pursue this in a separate PR.
* feat(nodes): do not overwrite custom node module names
  Use a different, simpler method to detect if a node is custom.
* feat(nodes): restore WithWorkflow as no-op class
  This class is deprecated and no longer needed. Set its workflow attr value to None (meaning it is now a no-op), and issue a warning when an invocation subclasses it.
* fix(nodes): fix get_workflow from queue item dict func
* feat(backend): add WorkflowRecordListItemDTO
  This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflows.
* chore(ui): typegen
* feat(ui): add workflow loading, deleting to workflow library UI
* feat(ui): workflow library pagination button styles
* wip
* feat: workflow library WIP
  - Save to library
  - Duplicate
  - Filter/sort
  - UI/queries
* feat: workflow library - system graphs - wip
* feat(backend): sync system workflows to db
* fix: merge conflicts
* feat: simplify default workflows
  - Rename "system" -> "default"
  - Simplify syncing logic
  - Update UI to match
* feat(workflows): update default workflows
  - Update TextToImage_SD15
  - Add TextToImage_SDXL
  - Add README
* feat(ui): refine workflow list UI
* fix(workflow_records): typo
* fix(tests): fix tests
* feat(ui): clean up workflow library hooks
* fix(db): fix mis-ordered db cleanup step
  It was happening before pruning queue items - should happen afterwards, else you have to restart the app again to free disk space made available by the pruning.
* feat(ui): tweak reset workflow editor translations
* feat(ui): split out workflow redux state
  The `nodes` slice is a rather complicated slice. Removing `workflow` makes it a bit more reasonable. Also helps to flatten state out a bit.
* docs: update default workflows README
* fix: tidy up unused files, unrelated changes
* fix(backend): revert unrelated service organisational changes
* feat(backend): workflow_records.get_many arg "filter_text" -> "query"
* feat(ui): use custom hook in current image buttons
  Already in use elsewhere, forgot to use it here.
* fix(ui): remove commented out property
* fix(ui): fix workflow loading
  - Different handling for loading from library vs external
  - Fix bug where only nodes and edges loaded
* fix(ui): fix save/save-as workflow naming
* fix(ui): fix circular dependency
* fix(db): fix bug with releasing without lock in db.clean()
* fix(db): remove extraneous lock
* chore: bump ruff
* fix(workflow_records): default `category` to `WorkflowCategory.User`
  This allows old workflows to validate when reading them from the db or image files.
* hide workflow library buttons if feature is disabled
---------
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
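The "ignore duplicates" fix for workflow insertion is not shown on this page. As a rough illustration of the general technique only, the sketch below uses SQLite's `INSERT OR IGNORE`. The table name `workflows_demo`, its columns, and the helper `insert_workflow_if_new` are hypothetical and are not InvokeAI's actual schema or code.

```python
import sqlite3


def insert_workflow_if_new(conn: sqlite3.Connection, workflow_id: str, workflow_json: str) -> bool:
    """Insert a row unless its primary key already exists; return True if a row was written."""
    cursor = conn.execute(
        # OR IGNORE silently skips the insert when it would violate the primary key constraint
        "INSERT OR IGNORE INTO workflows_demo (workflow_id, workflow) VALUES (?, ?);",
        (workflow_id, workflow_json),
    )
    conn.commit()
    # rowcount is 1 when a row was written and 0 when the insert was ignored
    return cursor.rowcount == 1


# Hypothetical in-memory table used only for this demonstration
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflows_demo (workflow_id TEXT PRIMARY KEY, workflow TEXT NOT NULL);")
first = insert_workflow_if_new(conn, "abc", '{"name": "demo"}')
second = insert_workflow_if_new(conn, "abc", '{"name": "demo"}')
row_count = conn.execute("SELECT COUNT(*) FROM workflows_demo;").fetchone()[0]
```

Inserting the same id twice leaves a single row and raises no error, which is the behavior the commit title describes.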
594 lines
23 KiB
Python
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Union

import cv2
import numpy as np
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LeresDetector,
    LineartAnimeDetector,
    LineartDetector,
    MediapipeFaceDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
    SamDetector,
    ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator

from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions

from ...backend.model_management import BaseModelType
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    WithMetadata,
    invocation,
    invocation_output,
)

CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
    "just_resize",
    "crop_resize",
    "fill_resize",
    "just_resize_simple",
]


class ControlNetModelField(BaseModel):
    """ControlNet model field"""

    model_name: str = Field(description="Name of the ControlNet model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())


class ControlField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v):
        """Validate that all control weights are in the valid range (-1 to 2)"""
        if isinstance(v, list):
            for i in v:
                if i < -1 or i > 2:
                    raise ValueError("Control weights must be within -1 to 2 range")
        else:
            if v < -1 or v > 2:
                raise ValueError("Control weights must be within -1 to 2 range")
        return v


@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Outputs
    control: ControlField = OutputField(description=FieldDescriptions.control)


@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.0")
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    image: ImageField = InputField(description="The control image")
    control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the ControlNet"
    )
    # bounds match ControlField.begin_step_percent, which rejects values outside 0..1
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    def invoke(self, context: InvocationContext) -> ControlOutput:
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )


# This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
    """Base class for invocations that preprocess images for ControlNet"""

    image: ImageField = InputField(description="The image to process")

    def run_processor(self, image: Image.Image) -> Image.Image:
        # superclass just passes through image without processing
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        raw_image = context.services.images.get_pil_image(self.image.image_name)
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # currently can't see processed image in node UI without a showImage node,
        # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
        image_dto = context.services.images.create(
            image=processed_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.CONTROL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata,
            workflow=context.workflow,
        )

        # Builds an ImageOutput and its ImageField
        processed_image_field = ImageField(image_name=image_dto.image_name)
        return ImageOutput(
            image=processed_image_field,
            # width=processed_image.width,
            width=image_dto.width,
            # height=processed_image.height,
            height=image_dto.height,
            # mode=processed_image.mode,
        )


@invocation(
    "canny_image_processor",
    title="Canny Processor",
    tags=["controlnet", "canny"],
    category="controlnet",
    version="1.2.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
    """Canny edge detection for ControlNet"""

    low_threshold: int = InputField(
        default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
    )
    high_threshold: int = InputField(
        default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
    )

    def run_processor(self, image):
        canny_processor = CannyDetector()
        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
        return processed_image


@invocation(
    "hed_image_processor",
    title="HED (softedge) Processor",
    tags=["controlnet", "hed", "softedge"],
    category="controlnet",
    version="1.2.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
    """Applies HED edge detection to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            # safe not supported in controlnet_aux v0.0.3
            # safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "lineart_image_processor",
    title="Lineart Processor",
    tags=["controlnet", "lineart"],
    category="controlnet",
    version="1.2.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    coarse: bool = InputField(default=False, description="Whether to use coarse mode")

    def run_processor(self, image):
        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = lineart_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
        )
        return processed_image


@invocation(
    "lineart_anime_image_processor",
    title="Lineart Anime Processor",
    tags=["controlnet", "lineart", "anime"],
    category="controlnet",
    version="1.2.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art anime processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "openpose_image_processor",
    title="Openpose Processor",
    tags=["controlnet", "openpose", "pose"],
    category="controlnet",
    version="1.2.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Openpose processing to image"""

    hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            hand_and_face=self.hand_and_face,
        )
        return processed_image


@invocation(
    "midas_depth_image_processor",
    title="Midas Depth Processor",
    tags=["controlnet", "midas"],
    category="controlnet",
    version="1.2.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Midas depth processing to image"""

    a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(
            image,
            a=np.pi * self.a_mult,
            bg_th=self.bg_th,
            # depth_and_normal not supported in controlnet_aux v0.0.3
            # depth_and_normal=self.depth_and_normal,
        )
        return processed_image


@invocation(
    "normalbae_image_processor",
    title="Normal BAE Processor",
    tags=["controlnet"],
    category="controlnet",
    version="1.2.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies NormalBae processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
        )
        return processed_image


@invocation(
    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
    """Applies MLSD processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            thr_v=self.thr_v,
            thr_d=self.thr_d,
        )
        return processed_image


@invocation(
    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
    """Applies PIDI processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "content_shuffle_image_processor",
    title="Content Shuffle Processor",
    tags=["controlnet", "contentshuffle"],
    category="controlnet",
    version="1.2.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
    """Applies content shuffle processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
    w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
    f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            h=self.h,
            w=self.w,
            f=self.f,
        )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@invocation(
    "zoe_depth_image_processor",
    title="Zoe (Depth) Processor",
    tags=["controlnet", "zoe", "depth"],
    category="controlnet",
    version="1.2.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Zoe depth processing to image"""

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


@invocation(
    "mediapipe_face_processor",
    title="Mediapipe Face Processor",
    tags=["controlnet", "mediapipe", "face"],
    category="controlnet",
    version="1.2.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
    """Applies mediapipe face processing to image"""

    max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")

    def run_processor(self, image):
        # MediaPipeFaceDetector throws an error if image has alpha channel
        # so convert to RGB if needed
        if image.mode == "RGBA":
            image = image.convert("RGB")
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image


@invocation(
    "leres_image_processor",
    title="Leres (Depth) Processor",
    tags=["controlnet", "leres", "depth"],
    category="controlnet",
    version="1.2.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
    """Applies leres processing to image"""

    thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
    thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
    boost: bool = InputField(default=False, description="Whether to use boost mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "tile_image_processor",
    title="Tile Resample Processor",
    tags=["controlnet", "tile"],
    category="controlnet",
    version="1.2.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
    """Tile resampler processor"""

    # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # never used?
        down_sampling_rate=1.0,
    ):
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        H, W, C = np_img.shape
        H = int(float(H) / float(down_sampling_rate))
        W = int(float(W) / float(down_sampling_rate))
        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
        return np_img

    def run_processor(self, img):
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate,
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


@invocation(
    "segment_anything_processor",
    title="Segment Anything Processor",
    tags=["controlnet", "segmentanything"],
    category="controlnet",
    version="1.2.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    def run_processor(self, image):
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):
    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    #     base class show_anns() method randomizes colors,
    #     which seems to also lead to non-reproducible image generation
    # so using ADE20k color palette instead
    def show_anns(self, anns: List[Dict]):
        if len(anns) == 0:
            return
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
        h, w = anns[0]["segmentation"].shape
        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann["segmentation"]
            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)


@invocation(
    "color_map_image_processor",
    title="Color Map Processor",
    tags=["controlnet"],
    category="controlnet",
    version="1.2.0",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
    """Generates a color map from the provided image"""

    # ge=1: a tile size of 0 would cause a division by zero in run_processor
    color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)

    def run_processor(self, image: Image.Image):
        image = image.convert("RGB")
        np_image = np.array(image, dtype=np.uint8)
        height, width = np_image.shape[:2]

        # clamp the tile size so that images smaller than one tile still yield a 1-pixel grid
        width_tile_size = min(self.color_map_tile_size, width)
        height_tile_size = min(self.color_map_tile_size, height)

        # downscale to one pixel per tile, then upscale with nearest-neighbour to get flat color blocks
        color_map = cv2.resize(
            np_image,
            (width // width_tile_size, height // height_tile_size),
            interpolation=cv2.INTER_CUBIC,
        )
        color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
        color_map = Image.fromarray(color_map)
        return color_map
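As a worked example of the arithmetic in `ColorMapImageProcessorInvocation.run_processor`, the sketch below computes the size of the intermediate downscaled grid without needing OpenCV. The helper name `color_map_grid_size` is hypothetical and not part of InvokeAI. Rejecting tile sizes below 1 is an assumption added for this sketch, since a tile size of 0 would make the integer division fail.

```python
from typing import Tuple


def color_map_grid_size(width: int, height: int, tile_size: int) -> Tuple[int, int]:
    """Return (columns, rows) of the downscaled grid for an image of the given size."""
    if tile_size < 1:
        # assumption for this sketch: guard against dividing by zero below
        raise ValueError("tile_size must be at least 1")
    # clamp the tile size to the image so small images still produce at least one cell
    width_tile_size = min(tile_size, width)
    height_tile_size = min(tile_size, height)
    return width // width_tile_size, height // height_tile_size


grid = color_map_grid_size(512, 768, 64)
small = color_map_grid_size(32, 768, 64)
```

For a 512x768 image with a tile size of 64, `grid` is `(8, 12)`. For a 32-pixel-wide image the tile is clamped to the image width, so `small` is `(1, 12)`. The processor then resizes this grid back to the full image size with nearest-neighbour interpolation.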