diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py index 920bca081b..67fd7bb43e 100644 --- a/invokeai/app/util/controlnet_utils.py +++ b/invokeai/app/util/controlnet_utils.py @@ -107,7 +107,6 @@ def np_img_resize( w: int, device: torch.device = torch.device('cpu') ): - print("in np_img_resize") # if 'inpaint' in module: # np_img = np_img.astype(np.float32) # else: @@ -192,7 +191,6 @@ def np_img_resize( # if resize_mode == external_code.ResizeMode.RESIZE: if resize_mode == "just_resize": # RESIZE - print("just resizing") np_img = high_quality_resize(np_img, (w, h)) np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img @@ -207,7 +205,6 @@ def np_img_resize( # if resize_mode == external_code.ResizeMode.OUTER_FIT: if resize_mode == "fill_resize": # OUTER_FIT - print("fill + resizing") k = min(k0, k1) borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0) high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype) @@ -224,7 +221,6 @@ def np_img_resize( np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img else: # resize_mode == "crop_resize" (INNER_FIT) - print("crop + resizing") k = max(k0, k1) np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) new_h, new_w, _ = np_img.shape @@ -233,3 +229,60 @@ def np_img_resize( np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w] np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img + +def prepare_control_image( + # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]] + # but now should be able to assume that image is a single PIL.Image, which simplifies things + image: Image, + # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions? + # latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width) + width=512, # should be 8 * latent.shape[3] + height=512, # should be 8 * latent height[2] + # batch_size=1, # currently no batching + # num_images_per_prompt=1, # currently only single image + device="cuda", + dtype=torch.float16, + do_classifier_free_guidance=True, + control_mode="balanced", + resize_mode="just_resize_simple", +): + # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out + if (resize_mode == "just_resize_simple" or + resize_mode == "crop_resize_simple" or + resize_mode == "fill_resize_simple"): + image = image.convert("RGB") + if (resize_mode == "just_resize_simple"): + image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + elif (resize_mode == "crop_resize_simple"): # not yet implemented + pass + elif (resize_mode == "fill_resize_simple"): # not yet implemented + pass + nimage = np.array(image) + nimage = nimage[None, :] + nimage = np.concatenate([nimage], axis=0) + # normalizing RGB values to [0,1] range (in PIL.Image they are [0-255]) + nimage = np.array(nimage).astype(np.float32) / 255.0 + nimage = nimage.transpose(0, 3, 1, 2) + timage = torch.from_numpy(nimage) + + # use fancy lvmin controlnet resizing + elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"): + nimage = np.array(image) + timage, nimage = np_img_resize( + np_img=nimage, + resize_mode=resize_mode, + h=height, + w=width, + # device=torch.device('cpu') + device=device, + ) + else: + pass + print("ERROR: invalid resize_mode ==> ", resize_mode) + exit(1) + + timage = timage.to(device=device, dtype=dtype) + cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") + if do_classifier_free_guidance and not cfg_injection: + timage = torch.cat([timage] * 2) + return timage