Support color correction for img2img and inpainting (#613)

* Support color correction for img2img and inpainting, avoiding the shift to magenta seen when running images through img2img repeatedly.

* Fix docs for color correction

* add --init_color to prompt reconstruction

* For best results, the --init_color option should point to the *very first* image used in the sequence of img2img operations. Otherwise color correction will skew towards cyan.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
Danny Beer
2022-09-18 09:47:57 -04:00
committed by GitHub
parent 0a4397094e
commit 045aa7a9a3
6 changed files with 61 additions and 4 deletions

View File

@ -181,6 +181,10 @@ class Args(object):
switches.append('--seamless')
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0:
switches.append(f'-M {a["init_mask"]}')
if a['init_color'] and len(a['init_color'])>0:
switches.append(f'--init_color {a["init_color"]}')
if a['fit']:
switches.append(f'--fit')
if a['init_img'] and a['strength'] and a['strength']>0:
@ -493,6 +497,11 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
img2img_group.add_argument(
'-T',
'-fit',

View File

@ -22,7 +22,8 @@ class Completer:
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img','-M','--init_mask')):
if text.startswith(('-I', '--init_img','-M','--init_mask',
'--init_color')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
@ -57,6 +58,8 @@ class Completer:
path = text.replace('--init_mask=', '', 1).lstrip()
elif text.startswith('-M'):
path = text.replace('-M', '', 1).lstrip()
elif text.startswith('--init_color='):
path = text.replace('--init_color=', '', 1).lstrip()
else:
path = text
@ -100,6 +103,7 @@ if readline_available:
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',

View File

@ -15,6 +15,8 @@ import traceback
import transformers
import io
import hashlib
import cv2
import skimage
from omegaconf import OmegaConf
from PIL import Image, ImageOps
@ -220,6 +222,7 @@ class Generate:
init_mask = None,
fit = False,
strength = None,
init_color = None,
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
@ -362,6 +365,11 @@ class Generate:
embiggen_tiles = embiggen_tiles,
)
if init_color:
self.correct_colors(image_list = results,
reference_image_path = init_color,
image_callback = image_callback)
if upscale is not None or gfpgan_strength > 0:
self.upscale_and_reconstruct(results,
upscale = upscale,
@ -475,6 +483,28 @@ class Generate:
return self.model
def correct_colors(self,
image_list,
reference_image_path,
image_callback = None):
reference_image = Image.open(reference_image_path)
correction_target = cv2.cvtColor(np.asarray(reference_image),
cv2.COLOR_RGB2LAB)
for r in image_list:
image, seed = r
image = cv2.cvtColor(np.asarray(image),
cv2.COLOR_RGB2LAB)
image = skimage.exposure.match_histograms(image,
correction_target,
channel_axis=2)
image = Image.fromarray(
cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8")
)
if image_callback is not None:
image_callback(image, seed)
else:
r[0] = image
def upscale_and_reconstruct(self,
image_list,
upscale = None,