mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Added revised prepare_control_image() that leverages lvmin high quality resizing
This commit is contained in:
parent 6cb9167a1b
commit b8e0810ed1
@@ -107,7 +107,6 @@ def np_img_resize(
     w: int,
     device: torch.device = torch.device('cpu')
 ):
-    print("in np_img_resize")
     # if 'inpaint' in module:
     #     np_img = np_img.astype(np.float32)
     # else:
@@ -192,7 +191,6 @@ def np_img_resize(

     # if resize_mode == external_code.ResizeMode.RESIZE:
     if resize_mode == "just_resize":  # RESIZE
-        print("just resizing")
         np_img = high_quality_resize(np_img, (w, h))
         np_img = safe_numpy(np_img)
         return get_pytorch_control(np_img), np_img
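Note: high_quality_resize is not part of this hunk; it is the lvmin-style resizer this commit leans on. Roughly, it switches interpolation depending on whether the control image is being shrunk or enlarged (and the real helper reportedly also special-cases near-binary maps such as canny edges). A minimal illustrative sketch of that idea only, not the upstream helper:

    import cv2
    import numpy as np

    def simple_high_quality_resize(np_img: np.ndarray, size):
        """Illustrative only: downscale with INTER_AREA, upscale with INTER_LANCZOS4."""
        new_w, new_h = size  # matches the (w, h) tuple passed in the hunk above
        old_h, old_w = np_img.shape[:2]
        shrinking = new_w < old_w or new_h < old_h
        interp = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
        return cv2.resize(np_img, (new_w, new_h), interpolation=interp)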
@@ -207,7 +205,6 @@ def np_img_resize(

     # if resize_mode == external_code.ResizeMode.OUTER_FIT:
     if resize_mode == "fill_resize":  # OUTER_FIT
-        print("fill + resizing")
         k = min(k0, k1)
         borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
         high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
@@ -224,7 +221,6 @@ def np_img_resize(
         np_img = safe_numpy(np_img)
         return get_pytorch_control(np_img), np_img
     else:  # resize_mode == "crop_resize" (INNER_FIT)
-        print("crop + resizing")
         k = max(k0, k1)
         np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
         new_h, new_w, _ = np_img.shape
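Note: k0 and k1 are computed earlier in np_img_resize, outside this diff; they are assumed here to be the per-axis scale factors (k0 = h / old_h, k1 = w / old_w). The two branches above differ only in which factor they keep: min() makes the image fit inside the target and the leftover border is filled with the median edge color, while max() makes it cover the target and the excess is center-cropped via pad_h/pad_w. A sketch of that geometry under those assumptions, not the upstream code:

    def fit_scale(old_w: int, old_h: int, w: int, h: int, mode: str) -> float:
        """Illustrative: choose one uniform scale factor for the control image."""
        k0 = float(h) / old_h  # scale needed to match target height
        k1 = float(w) / old_w  # scale needed to match target width
        if mode == "fill_resize":  # OUTER_FIT: fits inside target, borders padded
            return min(k0, k1)
        return max(k0, k1)         # crop_resize / INNER_FIT: covers target, then cropped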
@@ -233,3 +229,60 @@ def np_img_resize(
         np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
         np_img = safe_numpy(np_img)
         return get_pytorch_control(np_img), np_img
+
+def prepare_control_image(
+    # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
+    # but now should be able to assume that image is a single PIL.Image, which simplifies things
+    image: Image,
+    # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
+    # latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
+    width=512,  # should be 8 * latent.shape[3]
+    height=512,  # should be 8 * latent height[2]
+    # batch_size=1, # currently no batching
+    # num_images_per_prompt=1, # currently only single image
+    device="cuda",
+    dtype=torch.float16,
+    do_classifier_free_guidance=True,
+    control_mode="balanced",
+    resize_mode="just_resize_simple",
+):
+    # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
+    if (resize_mode == "just_resize_simple" or
+        resize_mode == "crop_resize_simple" or
+        resize_mode == "fill_resize_simple"):
+        image = image.convert("RGB")
+        if (resize_mode == "just_resize_simple"):
+            image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+        elif (resize_mode == "crop_resize_simple"):  # not yet implemented
+            pass
+        elif (resize_mode == "fill_resize_simple"):  # not yet implemented
+            pass
+        nimage = np.array(image)
+        nimage = nimage[None, :]
+        nimage = np.concatenate([nimage], axis=0)
+        # normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
+        nimage = np.array(nimage).astype(np.float32) / 255.0
+        nimage = nimage.transpose(0, 3, 1, 2)
+        timage = torch.from_numpy(nimage)
+
+    # use fancy lvmin controlnet resizing
+    elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
+        nimage = np.array(image)
+        timage, nimage = np_img_resize(
+            np_img=nimage,
+            resize_mode=resize_mode,
+            h=height,
+            w=width,
+            # device=torch.device('cpu')
+            device=device,
+        )
+    else:
+        pass
+        print("ERROR: invalid resize_mode ==> ", resize_mode)
+        exit(1)
+
+    timage = timage.to(device=device, dtype=dtype)
+    cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+    if do_classifier_free_guidance and not cfg_injection:
+        timage = torch.cat([timage] * 2)
+    return timage
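For orientation, a rough example of how the new prepare_control_image() might be called once it is importable from this module (the input path and the shape comment are illustrative, not taken from InvokeAI's pipeline code):

    from PIL import Image
    import torch

    control_image = Image.open("pose.png")  # hypothetical input image

    # With do_classifier_free_guidance=True and control_mode="balanced", the result
    # is doubled along the batch dimension (the torch.cat([timage] * 2) above) so it
    # lines up with an [unconditional, conditional] latent batch.
    control = prepare_control_image(
        image=control_image,
        width=512,
        height=512,
        device="cuda",
        dtype=torch.float16,
        do_classifier_free_guidance=True,
        control_mode="balanced",
        resize_mode="just_resize",  # use the lvmin high-quality path
    )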