diff --git a/ldm/generate.py b/ldm/generate.py index 7346f8dfe2..db36717135 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -806,6 +806,7 @@ class Generate: if not self.generators.get('inpaint'): from ldm.invoke.generator.inpaint import Inpaint self.generators['inpaint'] = Inpaint(self.model, self.precision) + self.generators['inpaint'].free_gpu_mem = self.free_gpu_mem return self.generators['inpaint'] # "omnibus" supports the runwayML custom inpainting model, which does diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 7b4d151265..b86e21efe6 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -59,7 +59,7 @@ class Inpaint(Img2Img): writeable=False ) - def infill_patchmatch(self, im: Image.Image) -> Image: + def infill_patchmatch(self, im: Image.Image) -> Image: if im.mode != 'RGBA': return im @@ -128,7 +128,7 @@ class Inpaint(Img2Img): # Combine npmask = npgradient + npedge - # Expand + # Expand npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) new_mask = Image.fromarray(npmask) @@ -221,7 +221,7 @@ class Inpaint(Img2Img): init_filled = init_filled.resize((inpaint_width, inpaint_height)) debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging) - + # Create init tensor init_image = self._image_to_tensor(init_filled.convert('RGB')) @@ -254,7 +254,7 @@ class Inpaint(Img2Img): f">> Using recommended DDIM sampler for inpainting." ) sampler = DDIMSampler(self.model, device=self.model.device) - + sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) @@ -291,6 +291,9 @@ class Inpaint(Img2Img): masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise z_enc = z_enc * mask_image + masked_region + if self.free_gpu_mem and self.model.model.device != self.model.device: + self.model.model.to(self.model.device) + # decode it samples = sampler.decode( z_enc, @@ -353,7 +356,7 @@ class Inpaint(Img2Img): if self.pil_image is None or self.pil_mask is None: return gen_result - + corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)