diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py index c5a14f7ed0..95d07921ac 100644 --- a/backend/modules/parameters.py +++ b/backend/modules/parameters.py @@ -40,6 +40,8 @@ def parameters_to_command(params): switches.append(f'-I {params["init_img"]}') if 'init_mask' in params and len(params['init_mask']) > 0: switches.append(f'-M {params["init_mask"]}') + if 'init_color' in params and len(params['init_color']) > 0: + switches.append(f'--init_color {params["init_color"]}') if 'strength' in params and 'init_img' in params: switches.append(f'-f {params["strength"]}') if 'fit' in params and params["fit"] == True: @@ -129,6 +131,11 @@ def create_cmd_parser(): type=str, help='Path to input mask for inpainting mode (supersedes width and height)', ) + parser.add_argument( + '--init_color', + type=str, + help='Path to reference image for color correction (used for repeated img2img and inpainting)' + ) parser.add_argument( '-T', '-fit', diff --git a/docs/features/CLI.md b/docs/features/CLI.md index 081af3dbf4..5f7cdaf162 100644 --- a/docs/features/CLI.md +++ b/docs/features/CLI.md @@ -154,13 +154,19 @@ vary greatly depending on what is in the image. We also ask to --fit the image i than 640x480. Otherwise the image size will be identical to the provided photo and you may run out of memory if it is large. +Repeated chaining of img2img on an image can result in significant color shifts +in the output, especially if run with lower strength. Color correction can be +run against a reference image to fix this issue. Use the original input image to the +chain as the the reference image for each step in the chain. + In addition to the command-line options recognized by txt2img, img2img accepts additional options: | Argument | Shortcut | Default | Description | | ------------------ | --------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------ | -| --init_img | -I | None | Path to the initialization image | -| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions | -| --strength | -s | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. | +| --init_img | -I | None | Path to the initialization image | +| --init_color | | None | Path to reference image for color correction | +| --fit | -F | False | Scale the image to fit into the specified -H and -W dimensions | +| --strength | -s | 0.75 | How hard to try to match the prompt to the initial image. Ranges from 0.0-0.99, with higher values replacing the initial image completely. | ### This is an example of inpainting diff --git a/ldm/dream/args.py b/ldm/dream/args.py index d33a8039b2..5ede5209f6 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -181,6 +181,10 @@ class Args(object): switches.append('--seamless') if a['init_img'] and len(a['init_img'])>0: switches.append(f'-I {a["init_img"]}') + if a['init_mask'] and len(a['init_mask'])>0: + switches.append(f'-M {a["init_mask"]}') + if a['init_color'] and len(a['init_color'])>0: + switches.append(f'--init_color {a["init_color"]}') if a['fit']: switches.append(f'--fit') if a['init_img'] and a['strength'] and a['strength']>0: @@ -493,6 +497,11 @@ class Args(object): type=str, help='Path to input mask for inpainting mode (supersedes width and height)', ) + img2img_group.add_argument( + '--init_color', + type=str, + help='Path to reference image for color correction (used for repeated img2img and inpainting)' + ) img2img_group.add_argument( '-T', '-fit', diff --git a/ldm/dream/readline.py b/ldm/dream/readline.py index 2aa8520acf..da94f5a61f 100644 --- a/ldm/dream/readline.py +++ b/ldm/dream/readline.py @@ -22,7 +22,8 @@ class Completer: def complete(self, text, state): buffer = readline.get_line_buffer() - if text.startswith(('-I', '--init_img','-M','--init_mask')): + if text.startswith(('-I', '--init_img','-M','--init_mask', + '--init_color')): return self._path_completions(text, state, ('.png','.jpg','.jpeg')) if buffer.strip().endswith('cd') or text.startswith(('.', '/')): @@ -57,6 +58,8 @@ class Completer: path = text.replace('--init_mask=', '', 1).lstrip() elif text.startswith('-M'): path = text.replace('-M', '', 1).lstrip() + elif text.startswith('--init_color='): + path = text.replace('--init_color=', '', 1).lstrip() else: path = text @@ -100,6 +103,7 @@ if readline_available: '--individual','-i', '--init_img','-I', '--init_mask','-M', + '--init_color', '--strength','-f', '--variants','-v', '--outdir','-o', diff --git a/ldm/generate.py b/ldm/generate.py index 1b3c8544e0..8bb40d0553 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -15,6 +15,8 @@ import traceback import transformers import io import hashlib +import cv2 +import skimage from omegaconf import OmegaConf from PIL import Image, ImageOps @@ -220,6 +222,7 @@ class Generate: init_mask = None, fit = False, strength = None, + init_color = None, # these are specific to embiggen (which also relies on img2img args) embiggen = None, embiggen_tiles = None, @@ -362,6 +365,11 @@ class Generate: embiggen_tiles = embiggen_tiles, ) + if init_color: + self.correct_colors(image_list = results, + reference_image_path = init_color, + image_callback = image_callback) + if upscale is not None or gfpgan_strength > 0: self.upscale_and_reconstruct(results, upscale = upscale, @@ -475,6 +483,28 @@ class Generate: return self.model + def correct_colors(self, + image_list, + reference_image_path, + image_callback = None): + reference_image = Image.open(reference_image_path) + correction_target = cv2.cvtColor(np.asarray(reference_image), + cv2.COLOR_RGB2LAB) + for r in image_list: + image, seed = r + image = cv2.cvtColor(np.asarray(image), + cv2.COLOR_RGB2LAB) + image = skimage.exposure.match_histograms(image, + correction_target, + channel_axis=2) + image = Image.fromarray( + cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8") + ) + if image_callback is not None: + image_callback(image, seed) + else: + r[0] = image + def upscale_and_reconstruct(self, image_list, upscale = None, diff --git a/requirements.txt b/requirements.txt index efc55b7971..d0b739f82a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ pillow pip>=22 pudb pytorch-lightning +scikit-image>=0.19 streamlit # "CompVis/taming-transformers" IS NOT INSTALLABLE # This is a drop-in replacement