Support color correction for img2img and inpainting (#613)

* Support color correction for img2img and inpainting, avoiding the shift to magenta seen when running images through img2img repeatedly.

* Fix docs for color correction

* add --init_color to prompt reconstruction

* For best results, the --init_color option should point to the *very first* image used in the sequence of img2img operations. Otherwise color correction will skew towards cyan.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
Danny Beer 2022-09-18 09:47:57 -04:00 committed by GitHub
parent 0a4397094e
commit 045aa7a9a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 4 deletions

View File

@ -40,6 +40,8 @@ def parameters_to_command(params):
switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0:
switches.append(f'-M {params["init_mask"]}')
if 'init_color' in params and len(params['init_color']) > 0:
switches.append(f'--init_color {params["init_color"]}')
if 'strength' in params and 'init_img' in params:
switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True:
@ -129,6 +131,11 @@ def create_cmd_parser():
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
parser.add_argument(
'-T',
'-fit',

View File

@ -154,13 +154,19 @@ vary greatly depending on what is in the image. We also ask to --fit the image i
than 640x480. Otherwise the image size will be identical to the provided photo and you may run out
of memory if it is large.
Repeated chaining of img2img on an image can result in significant color shifts
in the output, especially if run with lower strength. Color correction can be
run against a reference image to fix this issue. Use the original input image to the
chain as the the reference image for each step in the chain.
In addition to the command-line options recognized by txt2img, img2img accepts additional options:
| Argument | Shortcut | Default | Description |
| ------------------ | --------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| --init_img <path> | -I<path> | None | Path to the initialization image |
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
| --init_img <path> | -I<path> | None | Path to the initialization image |
| --init_color <path> | | None | Path to reference image for color correction |
| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions |
| --strength <float> | -s<float> | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. |
### This is an example of inpainting

View File

@ -181,6 +181,10 @@ class Args(object):
switches.append('--seamless')
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0:
switches.append(f'-M {a["init_mask"]}')
if a['init_color'] and len(a['init_color'])>0:
switches.append(f'--init_color {a["init_color"]}')
if a['fit']:
switches.append(f'--fit')
if a['init_img'] and a['strength'] and a['strength']>0:
@ -493,6 +497,11 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
img2img_group.add_argument(
'-T',
'-fit',

View File

@ -22,7 +22,8 @@ class Completer:
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img','-M','--init_mask')):
if text.startswith(('-I', '--init_img','-M','--init_mask',
'--init_color')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
@ -57,6 +58,8 @@ class Completer:
path = text.replace('--init_mask=', '', 1).lstrip()
elif text.startswith('-M'):
path = text.replace('-M', '', 1).lstrip()
elif text.startswith('--init_color='):
path = text.replace('--init_color=', '', 1).lstrip()
else:
path = text
@ -100,6 +103,7 @@ if readline_available:
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',

View File

@ -15,6 +15,8 @@ import traceback
import transformers
import io
import hashlib
import cv2
import skimage
from omegaconf import OmegaConf
from PIL import Image, ImageOps
@ -220,6 +222,7 @@ class Generate:
init_mask = None,
fit = False,
strength = None,
init_color = None,
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
@ -362,6 +365,11 @@ class Generate:
embiggen_tiles = embiggen_tiles,
)
if init_color:
self.correct_colors(image_list = results,
reference_image_path = init_color,
image_callback = image_callback)
if upscale is not None or gfpgan_strength > 0:
self.upscale_and_reconstruct(results,
upscale = upscale,
@ -475,6 +483,28 @@ class Generate:
return self.model
def correct_colors(self,
image_list,
reference_image_path,
image_callback = None):
reference_image = Image.open(reference_image_path)
correction_target = cv2.cvtColor(np.asarray(reference_image),
cv2.COLOR_RGB2LAB)
for r in image_list:
image, seed = r
image = cv2.cvtColor(np.asarray(image),
cv2.COLOR_RGB2LAB)
image = skimage.exposure.match_histograms(image,
correction_target,
channel_axis=2)
image = Image.fromarray(
cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8")
)
if image_callback is not None:
image_callback(image, seed)
else:
r[0] = image
def upscale_and_reconstruct(self,
image_list,
upscale = None,

View File

@ -14,6 +14,7 @@ pillow
pip>=22
pudb
pytorch-lightning
scikit-image>=0.19
streamlit
# "CompVis/taming-transformers" IS NOT INSTALLABLE
# This is a drop-in replacement