mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
* Removed duplicate fix_func for MPS * add support for loading VAE autoencoders To add a VAE autoencoder to an existing model: 1. Download the appropriate autoencoder and put it into models/ldm/stable-diffusion Note that you MUST use a VAE that was written for the original CompViz Stable Diffusion codebase. For v1.4, that would be the file named vae-ft-mse-840000-ema-pruned.ckpt that you can download from https://huggingface.co/stabilityai/sd-vae-ft-mse-original 2. Edit config/models.yaml to contain the following stanza, modifying `weights` and `vae` as required to match the weights and vae model file names. There is no requirement to rename the VAE file. ~~~ stable-diffusion-1.4: weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt description: Stable Diffusion v1.4 config: configs/stable-diffusion/v1-inference.yaml vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512 height: 512 ~~~ 3. Alternatively from within the `invoke.py` CLI, you may use the command `!editmodel stable-diffusion-1.4` to bring up a simple editor that will allow you to add the path to the VAE. 4. If you are just installing InvokeAI for the first time, you can also use `!import_model models/ldm/stable-diffusion/sd-v1.4.ckpt` instead to create the configuration from scratch. 5. That's it! * ported code refactor changes from PR #1221 - pass a PIL.Image to img2img and inpaint rather than tensor - To support clipseg, inpaint needs to accept an "L" or "1" format mask. Made the appropriate change. * minor fixes to inpaint code 1. If tensors are passed to inpaint as init_image and/or init_mask, then the post-generation image fixup code will be skipped. 2. Post-generation image fixup will work with either a black and white "L" or "RGB" mask, or an "RGBA" mask. Co-authored-by: wfng92 <43742196+wfng92@users.noreply.github.com>
This commit is contained in:
@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
import PIL
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -25,6 +28,9 @@ class Img2Img(Generator):
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
@ -68,3 +74,11 @@ class Img2Img(Generator):
|
||||
shape = init_latent.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.model.device)
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
@ -13,16 +14,19 @@ from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
self.mask_blur_radius = 0
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
pil_image: Image.Image, pil_mask: Image.Image,
|
||||
mask_blur_radius: int = 8,
|
||||
step_callback=None,inpaint_replace=False, **kwargs):
|
||||
"""
|
||||
@ -31,17 +35,22 @@ class Inpaint(Img2Img):
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
# Get the alpha channel of the mask
|
||||
pil_init_mask = pil_mask.getchannel('A')
|
||||
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||
# histogram and cause slight color changes.
|
||||
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // downsampling,
|
||||
mask_image.height // downsampling
|
||||
),
|
||||
resample=Image.Resampling.NEAREST
|
||||
)
|
||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
@ -96,30 +105,50 @@ class Inpaint(Img2Img):
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
|
||||
# Get PIL result
|
||||
gen_result = self.sample_to_image(samples).convert('RGB')
|
||||
|
||||
# Get numpy version
|
||||
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||
|
||||
return matched_result
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def sample_to_image(self, samples)->Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
pil_mask = self.pil_mask
|
||||
pil_image = self.pil_image
|
||||
mask_blur_radius = self.mask_blur_radius
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
|
||||
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||
# histogram and cause slight color changes.
|
||||
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||
|
||||
# Get numpy version
|
||||
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
@ -13,6 +13,7 @@ import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import transformers
|
||||
import os
|
||||
from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
@ -193,6 +194,7 @@ class ModelCache(object):
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae',None)
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
@ -222,9 +224,17 @@ class ModelCache(object):
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae and os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
model.eval()
|
||||
|
||||
for m in model.modules():
|
||||
|
Reference in New Issue
Block a user