mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
defer patchmatch loading (#2039)
* defer patchmatch loading Because of the way that patchmatch was loaded early at import time, it was impossible to turn off the attempted loading with --no-patchmatch. In addition, the patchmatch loading messages appear early on during initialization, interfering with ability to print out the version cleanly when --version provided to invoke script. This commit creates a thin wrapper class for patch_match that is only loaded when needed, solving both problems. * create a singleton patchmatch object for use in inpainting This creates a thin wrapper to patchmatch which loads the module on demand, respecting the global "trypatchmatch" option. * address 2d round of issues in PR 2039 comments * Patchmatch->PatchMatch and misc cleanup
This commit is contained in:
@ -294,7 +294,7 @@ class InvokeAIWebServer:
|
|||||||
print(f">> System config requested")
|
print(f">> System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_cache.list_models()
|
config["model_list"] = self.generate.model_cache.list_models()
|
||||||
config["infill_methods"] = infill_methods
|
config["infill_methods"] = infill_methods()
|
||||||
socketio.emit("systemConfig", config)
|
socketio.emit("systemConfig", config)
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
|
@ -190,6 +190,7 @@ class Generate:
|
|||||||
self.txt2mask = None
|
self.txt2mask = None
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
self.karras_max = None
|
self.karras_max = None
|
||||||
|
self.infill_method = None
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -326,7 +327,7 @@ class Generate:
|
|||||||
seam_strength: float = 0.7,
|
seam_strength: float = 0.7,
|
||||||
seam_steps: int = 10,
|
seam_steps: int = 10,
|
||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
infill_method = infill_methods[0], # The infill method to use
|
infill_method = None,
|
||||||
force_outpaint: bool = False,
|
force_outpaint: bool = False,
|
||||||
enable_image_debugging = False,
|
enable_image_debugging = False,
|
||||||
|
|
||||||
@ -392,6 +393,7 @@ class Generate:
|
|||||||
self.log_tokenization = log_tokenization
|
self.log_tokenization = log_tokenization
|
||||||
self.step_callback = step_callback
|
self.step_callback = step_callback
|
||||||
self.karras_max = karras_max
|
self.karras_max = karras_max
|
||||||
|
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
|
||||||
with_variations = [] if with_variations is None else with_variations
|
with_variations = [] if with_variations is None else with_variations
|
||||||
|
|
||||||
# will instantiate the model or return it from cache
|
# will instantiate the model or return it from cache
|
||||||
|
@ -17,22 +17,15 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
from ldm.invoke.generator.base import downsampling
|
from ldm.invoke.generator.base import downsampling
|
||||||
from ldm.util import debug_image
|
from ldm.util import debug_image
|
||||||
|
from ldm.invoke.patchmatch import PatchMatch
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
infill_methods: list[str] = list()
|
def infill_methods()->list[str]:
|
||||||
|
methods = list()
|
||||||
if Globals.try_patchmatch:
|
if PatchMatch.patchmatch_available():
|
||||||
from patchmatch import patch_match
|
methods.append('patchmatch')
|
||||||
if patch_match.patchmatch_available:
|
methods.append('tile')
|
||||||
print('>> Patchmatch initialized')
|
return methods
|
||||||
infill_methods.append('patchmatch')
|
|
||||||
else:
|
|
||||||
print('>> Patchmatch not loaded (nonfatal)')
|
|
||||||
else:
|
|
||||||
print('>> Patchmatch loading disabled')
|
|
||||||
|
|
||||||
infill_methods.append('tile')
|
|
||||||
|
|
||||||
|
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -40,6 +33,7 @@ class Inpaint(Img2Img):
|
|||||||
self.pil_image = None
|
self.pil_image = None
|
||||||
self.pil_mask = None
|
self.pil_mask = None
|
||||||
self.mask_blur_radius = 0
|
self.mask_blur_radius = 0
|
||||||
|
self.infill_method = None
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
|
|
||||||
# Outpaint support code
|
# Outpaint support code
|
||||||
@ -64,11 +58,11 @@ class Inpaint(Img2Img):
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
# Skip patchmatch if patchmatch isn't available
|
# Skip patchmatch if patchmatch isn't available
|
||||||
if not patch_match.patchmatch_available:
|
if not PatchMatch.patchmatch_available():
|
||||||
return im
|
return im
|
||||||
|
|
||||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||||
im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
|
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
|
||||||
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
|
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
|
||||||
return im_patched
|
return im_patched
|
||||||
|
|
||||||
@ -187,7 +181,7 @@ class Inpaint(Img2Img):
|
|||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
step_callback=None,
|
step_callback=None,
|
||||||
inpaint_replace=False, enable_image_debugging=False,
|
inpaint_replace=False, enable_image_debugging=False,
|
||||||
infill_method = infill_methods[0], # The infill method to use
|
infill_method = None,
|
||||||
inpaint_width=None,
|
inpaint_width=None,
|
||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -198,6 +192,7 @@ class Inpaint(Img2Img):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.enable_image_debugging = enable_image_debugging
|
self.enable_image_debugging = enable_image_debugging
|
||||||
|
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
|
||||||
|
|
||||||
self.inpaint_width = inpaint_width
|
self.inpaint_width = inpaint_width
|
||||||
self.inpaint_height = inpaint_height
|
self.inpaint_height = inpaint_height
|
||||||
@ -206,7 +201,7 @@ class Inpaint(Img2Img):
|
|||||||
self.pil_image = init_image.copy()
|
self.pil_image = init_image.copy()
|
||||||
|
|
||||||
# Do infill
|
# Do infill
|
||||||
if infill_method == 'patchmatch' and patch_match.patchmatch_available:
|
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
|
||||||
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||||
else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
|
else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
|
||||||
init_filled = self.tile_fill_missing(
|
init_filled = self.tile_fill_missing(
|
||||||
|
@ -27,6 +27,5 @@ else:
|
|||||||
# Where to look for the initialization file
|
# Where to look for the initialization file
|
||||||
Globals.initfile = 'invokeai.init'
|
Globals.initfile = 'invokeai.init'
|
||||||
|
|
||||||
# Awkward workaround to disable attempted loading of pypatchmatch
|
# Try loading patchmatch
|
||||||
# which is causing CI tests to error out.
|
|
||||||
Globals.try_patchmatch = True
|
Globals.try_patchmatch = True
|
||||||
|
44
ldm/invoke/patchmatch.py
Normal file
44
ldm/invoke/patchmatch.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
'''
|
||||||
|
This module defines a singleton object, "patchmatch" that
|
||||||
|
wraps the actual patchmatch object. It respects the global
|
||||||
|
"try_patchmatch" attribute, so that patchmatch loading can
|
||||||
|
be suppressed or deferred
|
||||||
|
'''
|
||||||
|
from ldm.invoke.globals import Globals
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class PatchMatch:
|
||||||
|
'''
|
||||||
|
Thin class wrapper around the patchmatch function.
|
||||||
|
'''
|
||||||
|
|
||||||
|
patch_match = None
|
||||||
|
tried_load:bool = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_patch_match(self):
|
||||||
|
if self.tried_load:
|
||||||
|
return
|
||||||
|
if Globals.try_patchmatch:
|
||||||
|
from patchmatch import patch_match as pm
|
||||||
|
if pm.patchmatch_available:
|
||||||
|
print('>> Patchmatch initialized')
|
||||||
|
else:
|
||||||
|
print('>> Patchmatch not loaded (nonfatal)')
|
||||||
|
self.patch_match = pm
|
||||||
|
else:
|
||||||
|
print('>> Patchmatch loading disabled')
|
||||||
|
self.tried_load = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patchmatch_available(self)->bool:
|
||||||
|
self._load_patch_match()
|
||||||
|
return self.patch_match and self.patch_match.patchmatch_available
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def inpaint(self,*args,**kwargs)->np.ndarray:
|
||||||
|
if self.patchmatch_available():
|
||||||
|
return self.patch_match.inpaint(*args,**kwargs)
|
Reference in New Issue
Block a user