defer patchmatch loading (#2039)

* defer patchmatch loading

Because of the way that patchmatch was loaded early at import time, it
was impossible to turn off the attempted loading with --no-patchmatch.

In addition, the patchmatch loading messages appear early on during
initialization, interfering with ability to print out the version
cleanly when --version provided to invoke script.

This commit creates a thin wrapper class for patch_match that is only
loaded when needed, solving both problems.

* create a singleton patchmatch object for use in inpainting

This creates a thin wrapper to patchmatch which loads the module
on demand, respecting the global "trypatchmatch" option.

* address 2d round of issues in PR 2039 comments

* Patchmatch->PatchMatch and misc cleanup
This commit is contained in:
Lincoln Stein 2022-12-20 18:32:35 -05:00 committed by GitHub
parent 464aafa862
commit cca8d14c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 23 deletions

View File

@ -294,7 +294,7 @@ class InvokeAIWebServer:
print(f">> System config requested")
config = self.get_system_config()
config["model_list"] = self.generate.model_cache.list_models()
config["infill_methods"] = infill_methods
config["infill_methods"] = infill_methods()
socketio.emit("systemConfig", config)
@socketio.on("requestModelChange")

View File

@ -190,6 +190,7 @@ class Generate:
self.txt2mask = None
self.safety_checker = None
self.karras_max = None
self.infill_method = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -326,7 +327,7 @@ class Generate:
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
infill_method = infill_methods[0], # The infill method to use
infill_method = None,
force_outpaint: bool = False,
enable_image_debugging = False,
@ -392,6 +393,7 @@ class Generate:
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.karras_max = karras_max
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache

View File

@ -17,22 +17,15 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.invoke.generator.base import downsampling
from ldm.util import debug_image
from ldm.invoke.patchmatch import PatchMatch
from ldm.invoke.globals import Globals
infill_methods: list[str] = list()
if Globals.try_patchmatch:
from patchmatch import patch_match
if patch_match.patchmatch_available:
print('>> Patchmatch initialized')
infill_methods.append('patchmatch')
else:
print('>> Patchmatch not loaded (nonfatal)')
else:
print('>> Patchmatch loading disabled')
infill_methods.append('tile')
def infill_methods()->list[str]:
methods = list()
if PatchMatch.patchmatch_available():
methods.append('patchmatch')
methods.append('tile')
return methods
class Inpaint(Img2Img):
def __init__(self, model, precision):
@ -40,6 +33,7 @@ class Inpaint(Img2Img):
self.pil_image = None
self.pil_mask = None
self.mask_blur_radius = 0
self.infill_method = None
super().__init__(model, precision)
# Outpaint support code
@ -64,11 +58,11 @@ class Inpaint(Img2Img):
return im
# Skip patchmatch if patchmatch isn't available
if not patch_match.patchmatch_available:
if not PatchMatch.patchmatch_available():
return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
return im_patched
@ -187,7 +181,7 @@ class Inpaint(Img2Img):
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, enable_image_debugging=False,
infill_method = infill_methods[0], # The infill method to use
infill_method = None,
inpaint_width=None,
inpaint_height=None,
**kwargs):
@ -198,6 +192,7 @@ class Inpaint(Img2Img):
"""
self.enable_image_debugging = enable_image_debugging
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
self.inpaint_width = inpaint_width
self.inpaint_height = inpaint_height
@ -206,7 +201,7 @@ class Inpaint(Img2Img):
self.pil_image = init_image.copy()
# Do infill
if infill_method == 'patchmatch' and patch_match.patchmatch_available:
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
init_filled = self.infill_patchmatch(self.pil_image.copy())
else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
init_filled = self.tile_fill_missing(

View File

@ -27,6 +27,5 @@ else:
# Where to look for the initialization file
Globals.initfile = 'invokeai.init'
# Awkward workaround to disable attempted loading of pypatchmatch
# which is causing CI tests to error out.
# Try loading patchmatch
Globals.try_patchmatch = True

44
ldm/invoke/patchmatch.py Normal file
View File

@ -0,0 +1,44 @@
'''
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
'''
from ldm.invoke.globals import Globals
import numpy as np
class PatchMatch:
'''
Thin class wrapper around the patchmatch function.
'''
patch_match = None
tried_load:bool = False
def __init__(self):
super().__init__()
@classmethod
def _load_patch_match(self):
if self.tried_load:
return
if Globals.try_patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
print('>> Patchmatch initialized')
else:
print('>> Patchmatch not loaded (nonfatal)')
self.patch_match = pm
else:
print('>> Patchmatch loading disabled')
self.tried_load = True
@classmethod
def patchmatch_available(self)->bool:
self._load_patch_match()
return self.patch_match and self.patch_match.patchmatch_available
@classmethod
def inpaint(self,*args,**kwargs)->np.ndarray:
if self.patchmatch_available():
return self.patch_match.inpaint(*args,**kwargs)