minor fixes to inpaint code

1. If tensors are passed to inpaint as init_image and/or init_mask, then
   the post-generation image fixup code will be skipped.

2. Post-generation image fixup will work with either a black and white "L"
   or "RGB"  mask, or an "RGBA" mask.
This commit is contained in:
Lincoln Stein 2022-10-22 22:28:54 -04:00
parent 9472945299
commit 7e27f189cf
2 changed files with 21 additions and 15 deletions

View File

@ -1,21 +1,22 @@
# This file describes the alternative machine learning models
# available to the dream script.
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
stable-diffusion-1.4:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
description: Stable Diffusion inference model version 1.4
width: 512
height: 512
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
description: Stable Diffusion inference model version 1.4
width: 512
height: 512
stable-diffusion-1.5:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
description: Stable Diffusion inference model version 1.5
width: 512
height: 512
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
description: Stable Diffusion inference model version 1.5
width: 512
height: 512
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
default: true

View File

@ -49,6 +49,7 @@ class Inpaint(Img2Img):
resample=Image.Resampling.NEAREST
)
mask_image = self._image_to_tensor(mask_image,normalize=False)
self.mask_blur_radius = mask_blur_radius
# klms samplers not supported yet, so ignore previous sampler
@ -110,13 +111,17 @@ class Inpaint(Img2Img):
def sample_to_image(self, samples)->Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
pil_mask = self.pil_mask
pil_image = self.pil_image
mask_blur_radius = self.mask_blur_radius
# Get the original alpha channel of the mask
pil_init_mask = pil_mask.convert('L')
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.