Merge branch 'inpaint-improvement' of https://github.com/Kyle0654/InvokeAI into inpaint-improvement

This commit is contained in:
Kyle Schouviller 2022-10-23 14:02:52 -07:00
commit 0c34554170
6 changed files with 128 additions and 115 deletions

View File

@@ -5,7 +5,6 @@
 # model requires a model config file, a weights file,
 # and the width and height of the images it
 # was trained on.
 stable-diffusion-1.4:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/model.ckpt
@@ -19,3 +18,4 @@ stable-diffusion-1.5:
     description: Stable Diffusion inference model version 1.5
     width: 512
     height: 512
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
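The new `vae:` key is optional per model. A minimal sketch (not InvokeAI's actual loader) of how such a key is consumed once OmegaConf has parsed `models.yaml` — the config path and model name mirror the diff, everything else is illustrative:

```python
from omegaconf import OmegaConf

config = OmegaConf.load('configs/models.yaml')
mconfig = config['stable-diffusion-1.5']
# DictConfig.get() returns the default for models that omit the override
vae = mconfig.get('vae', None)  # path string, or None if no VAE override is set
print(vae)  # models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
```

This matches the `mconfig.get('vae',None)` call added to `ModelCache` later in this commit.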

View File

@@ -58,24 +58,6 @@ torch.multinomial = fix_func(torch.multinomial)
 # this is fallback model in case no default is defined
 FALLBACK_MODEL_NAME='stable-diffusion-1.4'
 
-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
-
 """Simplified text to image API for stable diffusion/latent diffusion
 
 Example Usage:
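The deleted block is a duplicate: the hunk header shows the same monkey-patch still applied just above (`torch.multinomial = fix_func(torch.multinomial)` is the enclosing context line). For reference, a self-contained sketch of what the workaround does — on Apple's MPS backend several RNG ops were unsupported or unreliable, so sampling is redirected to the CPU and the result moved to the requested device:

```python
import torch

def fix_func(orig):
    # Only patch when running on an Apple-silicon MPS device
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        def new_func(*args, **kw):
            device = kw.get("device", "mps")  # where the caller wants the tensor
            kw["device"] = "cpu"              # ...but sample on the CPU
            return orig(*args, **kw).to(device)
        return new_func
    return orig  # no-op everywhere else

torch.randn = fix_func(torch.randn)
x = torch.randn(2, 3)  # on MPS machines: sampled on CPU, then moved to "mps"
```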
@@ -411,7 +393,7 @@ class Generate:
             log_tokens  =self.log_tokenization
         )
 
-        init_image,mask_image,pil_image,pil_mask = self._make_images(
+        init_image, mask_image = self._make_images(
             init_img,
             init_mask,
             width,
@@ -451,8 +433,6 @@ class Generate:
                 height=height,
                 init_img=init_img,      # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image,  # notice that init_image is different from init_img
-                pil_image=pil_image,
-                pil_mask=pil_mask,
                 mask_image=mask_image,
                 strength=strength,
                 threshold=threshold,
@@ -644,7 +624,7 @@ class Generate:
         init_image = None
         init_mask = None
         if not img:
-            return None, None, None, None
+            return None, None
 
         image = self._load_img(img)
@@ -654,23 +634,22 @@ class Generate:
 
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)
 
         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False
 
-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)
 
         if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
 
         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
 
-        return init_image, init_mask, image, mask_image
+        return init_image,init_mask
 
     def _make_base(self):
         if not self.generators.get('base'):
@@ -887,33 +866,15 @@ class Generate:
     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     # The mask is expected to have the region to be inpainted
    # with alpha transparency. It converts it into a black/white
@@ -930,7 +891,6 @@ class Generate:
         mask = ImageOps.invert(mask)
         return mask
 
-    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -940,18 +900,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask
 
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
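Taken together, these hunks change `_make_images()` from returning four values (two of them device tensors) to returning two PIL images, deferring tensor conversion to the generators. A hedged sketch of the new call contract — `gen` and the file path are placeholders, and the argument names follow the diff:

```python
# init_image / init_mask are now PIL images (or None), not torch tensors.
init_image, init_mask = gen._make_images(
    'photo.png',   # init_img: source image
    None,          # init_mask: optional explicit mask
    512, 512,      # width, height
    fit=True,
    text_mask=None,
)
assert init_image is None or hasattr(init_image, 'convert')  # PIL, not tensor
```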

View File

@@ -4,6 +4,9 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
 
 import torch
 import numpy as np
+import PIL
+from torch import Tensor
+from PIL import Image
 from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
 
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(
@@ -68,3 +74,11 @@ class Img2Img(Generator):
         shape = init_latent.shape
         x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)
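The new `_image_to_tensor()` helper centralizes the PIL-to-tensor conversion that was previously duplicated in the `Generate` class. A quick standalone check of its shape and range contract (run on the CPU here, without a model):

```python
import numpy as np
import torch
from PIL import Image

img = Image.new('RGB', (64, 64), (128, 128, 128))    # stand-in init image
t = np.array(img).astype(np.float32) / 255.0         # (H, W, C) in [0, 1]
t = torch.from_numpy(t[None].transpose(0, 3, 1, 2))  # (1, C, H, W)
t = 2.0 * t - 1.0                                    # normalize=True path: [-1, 1]
print(t.shape)                         # torch.Size([1, 3, 64, 64])
print(t.min().item(), t.max().item())  # ~0.004 for uniform gray 128
```

With `normalize=False` (used for masks in the next file) the tensor stays in `[0, 1]`.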

View File

@@ -6,6 +6,7 @@ import torch
 import torchvision.transforms as T
 import numpy as np
 import cv2 as cv
+import PIL
 from PIL import Image, ImageFilter
 from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
@@ -13,16 +14,19 @@ from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling
 
 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       pil_image: Image.Image, pil_mask: Image.Image,
                        mask_blur_radius: int = 8,
                        step_callback=None,inpaint_replace=False, **kwargs):
         """
@@ -31,17 +35,22 @@ class Inpaint(Img2Img):
         the time you call it. kwargs are 'init_latent' and 'strength'
         """
 
-        # Get the alpha channel of the mask
-        pil_init_mask = pil_mask.getchannel('A')
-        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
-
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        # Note that this doesn't use the mask, which would exclude some source image pixels from the
-        # histogram and cause slight color changes.
-        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
-        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
-        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
-        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius
 
         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
@@ -96,9 +105,32 @@ class Inpaint(Img2Img):
                 mask = mask_image,
                 init_latent = self.init_latent
             )
+            return self.sample_to_image(samples)
 
-            # Get PIL result
-            gen_result = self.sample_to_image(samples).convert('RGB')
+        return make_image
+
+    def sample_to_image(self, samples)->Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        if self.pil_image is None or self.pil_mask is None:
+            return gen_result
+
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask if there is one.
+        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
+        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
 
         # Get numpy version
         np_gen_result = np.asarray(gen_result, dtype=np.uint8)
@@ -107,7 +139,6 @@ class Inpaint(Img2Img):
         np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
         matched_result = Image.fromarray(np_matched_result, mode='RGB')
-
         # Blur the mask out (into init image) by specified amount
         if mask_blur_radius > 0:
             nm = np.asarray(pil_init_mask, dtype=np.uint8)
@@ -119,7 +150,5 @@ class Inpaint(Img2Img):
 
         # Paste original on color-corrected generation (using blurred mask)
         matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
         return matched_result
-
-        return make_image
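The net effect: `make_image()` now only decodes, and all post-processing lives in the overridden `sample_to_image()`, so the color correction and seam blending also apply whenever samples are converted to images. A self-contained sketch of that post-processing under stated simplifications — the helper name `blend_inpaint` is invented, and the mask feathering here is a plain Gaussian blur, whereas the code above prepares the blurred mask differently:

```python
import numpy as np
from PIL import Image, ImageFilter
from skimage.exposure import match_histograms

def blend_inpaint(gen_result: Image.Image, pil_image: Image.Image,
                  pil_mask: Image.Image, blur_radius: int = 8) -> Image.Image:
    # Mask convention: opaque (255) = keep original, transparent (0) = inpainted.
    init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
    # Histogram reference: only the source pixels that are actually visible.
    rgb = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(-1, 3)
    alpha = np.asarray(pil_image.convert('RGBA').getchannel('A'), dtype=np.uint8).reshape(-1)
    ref = rgb[alpha > 0].reshape(1, -1, 3)
    # Color-match the generated image to the visible source pixels.
    matched = match_histograms(np.asarray(gen_result.convert('RGB'), dtype=np.uint8),
                               ref, channel_axis=-1)
    result = Image.fromarray(matched.astype(np.uint8), mode='RGB')
    # Feather the mask so pasting the original back leaves no hard seam.
    blurred = init_mask.filter(ImageFilter.GaussianBlur(blur_radius)) if blur_radius > 0 else init_mask
    result.paste(pil_image.convert('RGB'), (0, 0), mask=blurred)
    return result
```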

View File

@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -193,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height
@@ -222,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')
 
+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
+
         model.eval()
 
         for m in model.modules():
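For clarity, the VAE-override step in isolation (adapted, like the hunk above, from AUTOMATIC1111's `modules/sd_models.py`); the function wrapper is added here for illustration:

```python
import torch

def load_vae_weights(model, vae_path: str) -> None:
    # Load the standalone VAE checkpoint on the CPU first.
    vae_ckpt = torch.load(vae_path, map_location="cpu")
    # "loss.*" keys belong to the VAE's training objective, not to inference.
    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if not k.startswith("loss")}
    # strict=False: the checkpoint covers only the first-stage autoencoder's weights.
    model.first_stage_model.load_state_dict(vae_dict, strict=False)
```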

View File

@@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])
 
+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)
 
     for field in ('width','height'):
@@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
     conf = config[model_name]
     new_config = {}
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value
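The new VAE prompt uses a small validate-or-skip loop: accept an empty answer for the optional field, otherwise re-prompt until the file exists. A generic version of the same pattern (the function name is invented):

```python
import os
from typing import Optional

def prompt_optional_path(msg: str) -> Optional[str]:
    while True:
        path = input(msg)
        if len(path) == 0:
            return None        # empty answer: optional field skipped
        if os.path.exists(path):
            return path        # valid path: accept it
        # otherwise: ask again
```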