mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into model-switching
This commit is contained in:
@ -34,6 +34,26 @@ from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
def new_func(*args, **kw):
|
||||
device = kw.get("device", "mps")
|
||||
kw["device"]="cpu"
|
||||
return orig(*args, **kw).to(device)
|
||||
return new_func
|
||||
return orig
|
||||
|
||||
torch.rand = fix_func(torch.rand)
|
||||
torch.rand_like = fix_func(torch.rand_like)
|
||||
torch.randn = fix_func(torch.randn)
|
||||
torch.randn_like = fix_func(torch.randn_like)
|
||||
torch.randint = fix_func(torch.randint)
|
||||
torch.randint_like = fix_func(torch.randint_like)
|
||||
torch.bernoulli = fix_func(torch.bernoulli)
|
||||
torch.multinomial = fix_func(torch.multinomial)
|
||||
|
||||
# this is fallback model in case no default is defined
|
||||
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||
@ -157,6 +177,7 @@ class Generate:
|
||||
self.precision = precision
|
||||
self.strength = 0.75
|
||||
self.seamless = False
|
||||
self.seamless_axes = {'x','y'}
|
||||
self.hires_fix = False
|
||||
self.embedding_path = embedding_path
|
||||
self.model = None # empty for now
|
||||
@ -172,6 +193,7 @@ class Generate:
|
||||
self.esrgan = esrgan
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -244,6 +266,7 @@ class Generate:
|
||||
height = None,
|
||||
sampler_name = None,
|
||||
seamless = False,
|
||||
seamless_axes = {'x','y'},
|
||||
log_tokenization = False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
@ -252,6 +275,7 @@ class Generate:
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
text_mask = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_color = None,
|
||||
@ -264,6 +288,8 @@ class Generate:
|
||||
codeformer_fidelity = None,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
# this is specific to inpainting and causes more extreme inpainting
|
||||
inpaint_replace = 0.0,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
@ -282,6 +308,8 @@ class Generate:
|
||||
seamless // whether the generated image should tile
|
||||
hires_fix // whether the Hires Fix should be applied during generation
|
||||
init_img // path to an initial image
|
||||
init_mask // path to a mask for the initial image
|
||||
text_mask // a text string that will be used to guide clipseg generation of the init_mask
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
@ -314,6 +342,7 @@ class Generate:
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
seamless = seamless or self.seamless
|
||||
seamless_axes = seamless_axes or self.seamless_axes
|
||||
hires_fix = hires_fix or self.hires_fix
|
||||
cfg_scale = cfg_scale or self.cfg_scale
|
||||
ddim_eta = ddim_eta or self.ddim_eta
|
||||
@ -331,10 +360,8 @@ class Generate:
|
||||
# to the width and height of the image training set
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert threshold >= 0.0, '--threshold must be >=0.0'
|
||||
@ -362,6 +389,7 @@ class Generate:
|
||||
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
|
||||
|
||||
width, height, _ = self._resolution_check(width, height, log=True)
|
||||
assert inpaint_replace >=0.0 and inpaint_replace <= 1.0,'inpaint_replace must be between 0.0 and 1.0'
|
||||
|
||||
if sampler_name and (sampler_name != self.sampler_name):
|
||||
self.sampler_name = sampler_name
|
||||
@ -388,7 +416,10 @@ class Generate:
|
||||
width,
|
||||
height,
|
||||
fit=fit,
|
||||
text_mask=text_mask,
|
||||
)
|
||||
|
||||
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
||||
if (init_image is not None) and (mask_image is not None):
|
||||
generator = self._make_inpaint()
|
||||
elif (embiggen != None or embiggen_tiles != None):
|
||||
@ -403,6 +434,7 @@ class Generate:
|
||||
generator.set_variation(
|
||||
self.seed, variation_amount, with_variations
|
||||
)
|
||||
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
@ -424,6 +456,7 @@ class Generate:
|
||||
perlin=perlin,
|
||||
embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
inpaint_replace=inpaint_replace,
|
||||
)
|
||||
|
||||
if init_color:
|
||||
@ -599,17 +632,14 @@ class Generate:
|
||||
width,
|
||||
height,
|
||||
fit=False,
|
||||
text_mask=None,
|
||||
):
|
||||
init_image = None
|
||||
init_mask = None
|
||||
if not img:
|
||||
return None, None
|
||||
|
||||
image = self._load_img(
|
||||
img,
|
||||
width,
|
||||
height,
|
||||
)
|
||||
image = self._load_img(img)
|
||||
|
||||
if image.width < self.width and image.height < self.height:
|
||||
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
|
||||
@ -627,10 +657,12 @@ class Generate:
|
||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||
|
||||
if mask:
|
||||
mask_image = self._load_img(
|
||||
mask, width, height) # this returns an Image
|
||||
mask_image = self._load_img(mask) # this returns an Image
|
||||
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
||||
|
||||
elif text_mask:
|
||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||
|
||||
return init_image, init_mask
|
||||
|
||||
def _make_base(self):
|
||||
@ -699,7 +731,7 @@ class Generate:
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
|
||||
@ -776,6 +808,23 @@ class Generate:
|
||||
else:
|
||||
r[0] = image
|
||||
|
||||
def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5):
|
||||
assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
|
||||
basename,_ = os.path.splitext(os.path.basename(image_path))
|
||||
if self.txt2mask is None:
|
||||
self.txt2mask = Txt2Mask(device = self.device)
|
||||
segmented = self.txt2mask.segment(image_path,prompt)
|
||||
trans = segmented.to_transparent()
|
||||
inverse = segmented.to_transparent(invert=True)
|
||||
mask = segmented.to_mask(threshold)
|
||||
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .')
|
||||
|
||||
callback(trans,f'{safe_prompt}.deselected',use_prefix=basename)
|
||||
callback(inverse,f'{safe_prompt}.selected',use_prefix=basename)
|
||||
callback(mask,f'{safe_prompt}.masked',use_prefix=basename)
|
||||
|
||||
# to help WebGUI - front end to generator util function
|
||||
def sample_to_image(self, samples):
|
||||
return self._make_base().sample_to_image(samples)
|
||||
@ -808,7 +857,7 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
def _load_img(self, img, width, height)->Image:
|
||||
def _load_img(self, img)->Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
print(
|
||||
@ -870,6 +919,29 @@ class Generate:
|
||||
mask = ImageOps.invert(mask)
|
||||
return mask
|
||||
|
||||
# TODO: The latter part of this method repeats code from _create_init_mask()
|
||||
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
||||
prompt = text_mask[0]
|
||||
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
||||
if self.txt2mask is None:
|
||||
self.txt2mask = Txt2Mask(device = self.device)
|
||||
|
||||
segmented = self.txt2mask.segment(image, prompt)
|
||||
mask = segmented.to_mask(float(confidence_level))
|
||||
mask = mask.convert('RGB')
|
||||
# now we adjust the size
|
||||
if fit:
|
||||
mask = self._fit_image(mask, (width, height))
|
||||
else:
|
||||
mask = self._squeeze_image(mask)
|
||||
mask = mask.resize((mask.width//downsampling, mask.height //
|
||||
downsampling), resample=Image.Resampling.NEAREST)
|
||||
mask = np.array(mask)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None].transpose(0, 3, 1, 2)
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask.to(self.device)
|
||||
|
||||
def _has_transparency(self, image):
|
||||
if image.info.get("transparency", None) is not None:
|
||||
return True
|
||||
|
Reference in New Issue
Block a user