From ccdecf21a3dda1dc051eed59a5567aa60406f2f2 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 3 Jun 2024 09:41:17 +1000
Subject: [PATCH] tidy(nodes): cnet processors

- Set `self._context=context` instead of changing the type signature of
  `run_processor`
- Tidy a few typing things
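
In miniature, the convention this adopts (a simplified sketch, not the
verbatim code; the elided load/save logic is unchanged):

    class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
        def invoke(self, context: InvocationContext) -> ImageOutput:
            # Stash the context so subclasses can reach services (model
            # loading/downloading) without widening run_processor's signature.
            self._context = context
            raw_image = self.load_image(context)
            processed_image = self.run_processor(raw_image)
            ...

        def run_processor(self, image: Image.Image) -> Image.Image:
            # Base implementation passes the image through; subclasses that
            # need services read them from self._context, e.g.
            # self._context.models.download_and_cache_model(...).
            return image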
---
 .../controlnet_image_processors.py | 52 ++++++++++---------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index e533583829..1e4ad672bf 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -132,7 +132,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     image: ImageField = InputField(description="The image to process")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # superclass just passes through image without processing
         return image
 
@@ -141,9 +141,10 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         return context.images.get_pil(self.image.image_name, "RGB")
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(raw_image, context)
+        processed_image = self.run_processor(raw_image)
 
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
@@ -184,7 +185,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
         # Keep alpha channel for Canny processing to detect edges of transparent areas
         return context.images.get_pil(self.image.image_name, "RGBA")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         processed_image = get_canny_edges(
             image,
             self.low_threshold,
@@ -211,7 +212,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         hed_processor = HEDProcessor()
         processed_image = hed_processor.run(
             image,
@@ -238,7 +239,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         lineart_processor = LineartProcessor()
         processed_image = lineart_processor.run(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
@@ -259,7 +260,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         processor = LineartAnimeProcessor()
         processed_image = processor.run(
             image,
@@ -286,7 +287,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
@@ -314,9 +315,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image: Image.Image = normalbae_processor(
+        processed_image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
         )
         return processed_image
@@ -333,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -356,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -384,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,
@@ -408,7 +409,7 @@
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image
@@ -429,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,
@@ -457,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -499,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img
 
-    def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
-        np_img = np.array(img, dtype=np.uint8)
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        np_img = np.array(image, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,
             # res=self.tile_size,
@@ -523,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"
@@ -569,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
 
     color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]
 
@@ -604,13 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+    def run_processor(self, image: Image.Image) -> Image.Image:
         def loader(model_path: Path):
             return DepthAnythingDetector.load_model(
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )
 
-        with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
+        with self._context.models.load_and_cache_model(
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
+        ) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
             processed_image = depth_anything_detector(image=image, resolution=self.resolution)
             return processed_image
@@ -631,10 +634,9 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
-        mm = context.models
-        onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
-        onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
 
         dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
         processed_image = dw_openpose(