Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
tidy(nodes): cnet processors
- Set `self._context = context` instead of changing the type signature of `run_processor`
- Tidy a few typing things
parent b124440023
commit ccdecf21a3
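The pattern, reduced to a minimal runnable sketch: `invoke()` stashes the context on the instance, so every subclass override of `run_processor()` keeps a uniform image-only signature and reaches services through `self._context`. The classes below are simplified stand-ins, not InvokeAI's actual API (the real `invoke` returns an `ImageOutput` and loads the image via the context):

from PIL import Image


class InvocationContext:
    """Simplified stand-in for InvokeAI's InvocationContext (services elided)."""


class ImageProcessorInvocation:
    def run_processor(self, image: Image.Image) -> Image.Image:
        # superclass just passes through image without processing
        return image

    def invoke(self, context: InvocationContext) -> Image.Image:
        # Store the context instead of threading it through run_processor.
        self._context = context
        raw_image = Image.new("RGB", (64, 64))  # stands in for self.load_image(context)
        return self.run_processor(raw_image)


class GrayscaleProcessorInvocation(ImageProcessorInvocation):
    def run_processor(self, image: Image.Image) -> Image.Image:
        # A subclass that needs services reads self._context here, e.g.
        # self._context.models.download_and_cache_model(...)
        return image.convert("L")


processed = GrayscaleProcessorInvocation().invoke(InvocationContext())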
@@ -132,7 +132,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):

     image: ImageField = InputField(description="The image to process")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # superclass just passes through image without processing
         return image
@@ -141,9 +141,10 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         return context.images.get_pil(self.image.image_name, "RGB")

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(raw_image, context)
+        processed_image = self.run_processor(raw_image)

         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
@@ -184,7 +185,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
         # Keep alpha channel for Canny processing to detect edges of transparent areas
         return context.images.get_pil(self.image.image_name, "RGBA")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         processed_image = get_canny_edges(
             image,
             self.low_threshold,
@@ -211,7 +212,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         hed_processor = HEDProcessor()
         processed_image = hed_processor.run(
             image,
@@ -238,7 +239,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         lineart_processor = LineartProcessor()
         processed_image = lineart_processor.run(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
@@ -259,7 +260,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         processor = LineartAnimeProcessor()
         processed_image = processor.run(
             image,
@@ -286,7 +287,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
@@ -314,9 +315,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image: Image.Image = normalbae_processor(
+        processed_image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
         )
         return processed_image
@@ -333,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -356,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -384,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,
@@ -408,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image
@@ -429,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,
@@ -457,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -499,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img

-    def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
-        np_img = np.array(img, dtype=np.uint8)
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        np_img = np.array(image, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,
             # res=self.tile_size,
@@ -523,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"
@@ -569,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):

     color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]

@@ -604,13 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         def loader(model_path: Path):
             return DepthAnythingDetector.load_model(
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )

-        with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
+        with self._context.models.load_and_cache_model(
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
+        ) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
             processed_image = depth_anything_detector(image=image, resolution=self.resolution)
             return processed_image
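The `load_and_cache_model(source=..., loader=...)` call above takes a loader callback, so the detector weights are only materialized on a cache miss. A hedged sketch of that idea with a hypothetical minimal cache (not InvokeAI's actual service implementation):

from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Iterator


class SimpleModelCache:
    """Hypothetical cache illustrating the loader-callback pattern."""

    def __init__(self) -> None:
        self._models: Dict[str, Any] = {}

    @contextmanager
    def load_and_cache_model(self, source: str, loader: Callable[[Path], Any]) -> Iterator[Any]:
        if source not in self._models:
            # Cache miss: resolve the source to a local file and run the
            # loader once; the real service downloads the file first if needed.
            local_path = Path("/models") / source.rsplit("/", 1)[-1]
            self._models[source] = loader(local_path)
        yield self._models[source]

Repeated invocations with the same `source` then reuse the already-loaded model, which keeps per-execution calls to `run_processor` cheap.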
@@ -631,10 +634,9 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
-        mm = context.models
-        onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
-        onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

         dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
         processed_image = dw_openpose(