From 94729452999448db481aa190b35bce44a553f1b9 Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lincoln.stein@gmail.com>
Date: Sat, 22 Oct 2022 20:06:45 -0400
Subject: [PATCH] ported code refactor changes from PR #1221

- pass a PIL.Image to img2img and inpaint rather than a tensor
- to support clipseg, inpaint now accepts a mask in "L" or "1"
  format
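
A rough sketch of the intended call pattern after this change (the
image paths and the inpaint generator instance below are
illustrative, not part of this patch):

    from PIL import Image

    init_image = Image.open('photo.png')              # placeholder path
    mask_image = Image.open('mask.png').convert('L')  # "L"/"1" masks OK

    # Img2Img/Inpaint convert PIL inputs to tensors internally via
    # the new _image_to_tensor() helper; masks skip the [-1, 1]
    # normalization (normalize=False).
    make_image = inpaint.get_make_image(
        prompt, sampler, steps, cfg_scale, ddim_eta, conditioning,
        init_image=init_image,
        mask_image=mask_image,
        strength=0.75,
    )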
---
 ldm/generate.py                 | 54 ++++---------------
 ldm/invoke/generator/img2img.py | 20 +++++--
 ldm/invoke/generator/inpaint.py | 96 ++++++++++++++++++++-------------
 3 files changed, 88 insertions(+), 82 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 964863ce6e..ce2331806c 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -393,7 +393,7 @@ class Generate:
                 log_tokens    =self.log_tokenization
             )
 
-            init_image,mask_image,pil_image,pil_mask = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -433,8 +433,6 @@ class Generate:
                 height=height,
                 init_img=init_img,        # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image,      # notice that init_image is different from init_img
-                pil_image=pil_image,
-                pil_mask=pil_mask,
                 mask_image=mask_image,
                 strength=strength,
                 threshold=threshold,
@@ -626,7 +624,7 @@ class Generate:
         init_image      = None
         init_mask       = None
         if not img:
-            return None, None, None, None
+            return None, None
 
         image = self._load_img(img)
 
@@ -636,23 +634,22 @@ class Generate:
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)
             
         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False
 
-        init_image   = self._create_init_image(image,width,height,fit=fit)                   # this returns a torch tensor
+        init_image   = self._create_init_image(image,width,height,fit=fit)
 
         if mask:
-            mask_image = self._load_img(mask)  # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
 
         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
 
-        return init_image, init_mask, image, mask_image
+        return init_image, init_mask
 
     def _make_base(self):
         if not self.generators.get('base'):
@@ -869,33 +866,15 @@ class Generate:
 
     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
@@ -912,7 +891,6 @@ class Generate:
             mask = ImageOps.invert(mask)
         return mask
 
-    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -922,18 +900,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask
 
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 7fde1a94cf..613f1aca31 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
 
 import torch
 import numpy as  np
-from ldm.invoke.devices             import choose_autocast
-from ldm.invoke.generator.base      import Generator
-from ldm.models.diffusion.ddim     import DDIMSampler
+import PIL
+from torch import Tensor
+from PIL import Image
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler
 
 class Img2Img(Generator):
     def __init__(self, model, precision):
@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
 
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)  # PIL -> normalized NCHW tensor
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(
@@ -68,3 +74,11 @@ class Img2Img(Generator):
             shape = init_latent.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image: Image.Image, normalize: bool = True) -> Tensor:
+        image = np.array(image).astype(np.float32) / 255.0  # scale to [0, 1]
+        image = image[None].transpose(0, 3, 1, 2)           # HWC -> NCHW, add batch dim
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0                       # shift to [-1, 1]
+        return image.to(self.model.device)
diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py
index ee67b90c46..fa0200e185 100644
--- a/ldm/invoke/generator/inpaint.py
+++ b/ldm/invoke/generator/inpaint.py
@@ -6,6 +6,7 @@ import torch
 import torchvision.transforms as T
 import numpy as  np
 import cv2 as cv
+import PIL
 from PIL import Image, ImageFilter
 from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
@@ -13,16 +14,19 @@ from ldm.invoke.devices             import choose_autocast
 from ldm.invoke.generator.img2img   import Img2Img
 from ldm.models.diffusion.ddim     import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling
 
 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       pil_image: Image.Image, pil_mask: Image.Image,
                        mask_blur_radius: int = 8,
                        step_callback=None,inpaint_replace=False, **kwargs):
         """
@@ -31,17 +35,21 @@ class Inpaint(Img2Img):
         the time you call it.  kwargs are 'init_latent' and 'strength'
         """
 
-        # Get the alpha channel of the mask
-        pil_init_mask = pil_mask.getchannel('A')
-        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
 
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        # Note that this doesn't use the mask, which would exclude some source image pixels from the
-        # histogram and cause slight color changes.
-        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
-        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
-        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
-        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST  # no interpolation of mask values
+            )
+            mask_image = self._image_to_tensor(mask_image, normalize=False)  # mask stays in [0, 1]
+        self.mask_blur_radius = mask_blur_radius
 
         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
@@ -96,30 +104,46 @@ class Inpaint(Img2Img):
                 mask                       = mask_image,
                 init_latent                = self.init_latent
             )
-
-            # Get PIL result
-            gen_result = self.sample_to_image(samples).convert('RGB')
-
-            # Get numpy version
-            np_gen_result = np.asarray(gen_result, dtype=np.uint8)
-
-            # Color correct
-            np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
-            matched_result = Image.fromarray(np_matched_result, mode='RGB')
-
-
-            # Blur the mask out (into init image) by specified amount
-            if mask_blur_radius > 0:
-                nm = np.asarray(pil_init_mask, dtype=np.uint8)
-                nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
-                pmd = Image.fromarray(nmd, mode='L')
-                blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
-            else:
-                blurred_init_mask = pil_init_mask
-
-            # Paste original on color-corrected generation (using blurred mask)
-            matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
-
-            return matched_result
+            return self.sample_to_image(samples)
 
         return make_image
+
+    def sample_to_image(self, samples) -> Image.Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        pil_image, pil_mask = self.pil_image, self.pil_mask
+        if pil_image is None or pil_mask is None: return gen_result  # tensor inputs: nothing to composite
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask
+        pil_init_mask = pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_image.width * pil_init_image.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')
+
+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result
+