mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into outpaint
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
|
||||
import pyparsing
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
@ -24,6 +24,7 @@ from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
@ -179,6 +180,7 @@ class Generate:
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
self.karras_max = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -269,10 +271,12 @@ class Generate:
|
||||
variation_amount = 0.0,
|
||||
threshold = 0.0,
|
||||
perlin = 0.0,
|
||||
karras_max = None,
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
text_mask = None,
|
||||
invert_mask = False,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_color = None,
|
||||
@ -317,6 +321,7 @@ class Generate:
|
||||
init_img // path to an initial image
|
||||
init_mask // path to a mask for the initial image
|
||||
text_mask // a text string that will be used to guide clipseg generation of the init_mask
|
||||
invert_mask // boolean, if true invert the mask
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
@ -357,7 +362,8 @@ class Generate:
|
||||
strength = strength or self.strength
|
||||
self.seed = seed
|
||||
self.log_tokenization = log_tokenization
|
||||
self.step_callback = step_callback
|
||||
self.step_callback = step_callback
|
||||
self.karras_max = karras_max
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
@ -402,6 +408,11 @@ class Generate:
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
if karras_max is not None and isinstance(self.sampler,KSampler):
|
||||
self.sampler.adjust_settings(karras_max=karras_max)
|
||||
|
||||
tic = time.time()
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@ -411,7 +422,7 @@ class Generate:
|
||||
mask_image = None
|
||||
|
||||
try:
|
||||
uc, c = get_uc_and_c(
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
log_tokens =self.log_tokenization
|
||||
@ -424,19 +435,11 @@ class Generate:
|
||||
height,
|
||||
fit=fit,
|
||||
text_mask=text_mask,
|
||||
invert_mask=invert_mask,
|
||||
)
|
||||
|
||||
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
||||
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
|
||||
generator = self._make_inpaint()
|
||||
elif (embiggen != None or embiggen_tiles != None):
|
||||
generator = self._make_embiggen()
|
||||
elif init_image is not None:
|
||||
generator = self._make_img2img()
|
||||
elif hires_fix:
|
||||
generator = self._make_txt2img2img()
|
||||
else:
|
||||
generator = self._make_txt2img()
|
||||
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
|
||||
|
||||
generator.set_variation(
|
||||
self.seed, variation_amount, with_variations
|
||||
@ -455,7 +458,7 @@ class Generate:
|
||||
sampler=self.sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=(uc, c),
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=image_callback, # called after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
@ -494,14 +497,14 @@ class Generate:
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
except KeyboardInterrupt:
|
||||
if catch_interrupts:
|
||||
print('**Interrupted** Partial results will be returned.')
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
@ -558,7 +561,7 @@ class Generate:
|
||||
# try to reuse the same filename prefix as the original file.
|
||||
# we take everything up to the first period
|
||||
prefix = None
|
||||
m = re.match('^([^.]+)\.',os.path.basename(image_path))
|
||||
m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
|
||||
if m:
|
||||
prefix = m.groups()[0]
|
||||
|
||||
@ -566,7 +569,8 @@ class Generate:
|
||||
image = Image.open(image_path)
|
||||
|
||||
# used by multiple postfixers
|
||||
uc, c = get_uc_and_c(
|
||||
# todo: cross-attention control
|
||||
uc, c, _ = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=opt.skip_normalize,
|
||||
log_tokens =opt.log_tokenization
|
||||
@ -611,10 +615,9 @@ class Generate:
|
||||
|
||||
elif tool == 'embiggen':
|
||||
# fetch the metadata from the image
|
||||
generator = self._make_embiggen()
|
||||
generator = self.select_generator(embiggen=True)
|
||||
opt.strength = 0.40
|
||||
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
|
||||
# embiggen takes an image path (sigh)
|
||||
generator.generate(
|
||||
prompt,
|
||||
sampler = self.sampler,
|
||||
@ -648,6 +651,31 @@ class Generate:
|
||||
print(f'* postprocessing tool {tool} is not yet supported')
|
||||
return None
|
||||
|
||||
def select_generator(
        self,
        init_image:Image.Image=None,
        mask_image:Image.Image=None,
        embiggen:bool=False,
        hires_fix:bool=False,
):
    '''
    Pick the generator to use for the requested operation and return it.

    The checks run in this order, and the first match wins:
      1. hires_fix is truthy                      -> self._make_txt2img2img()
      2. embiggen was requested                   -> self._make_embiggen()
      3. the sampler reports an inpainting model  -> self._make_omnibus()
      4. both init_image and mask_image are given -> self._make_inpaint()
      5. only init_image is given                 -> self._make_img2img()
      6. otherwise                                -> self._make_txt2img()

    init_image // initial image, or None
    mask_image // mask for the initial image, or None
    embiggen   // None or False mean "not requested"; any other value
               // (True, or an embiggen setting passed through by the
               // caller) selects the embiggen generator
    hires_fix  // if true, select the txt2img2img generator
    '''
    # Queried up front, before any branch, exactly as the original did.
    inpainting_model_in_use = self.sampler.uses_inpainting_model()

    if hires_fix:
        return self._make_txt2img2img()

    # The declared default is False, so a bare `is not None` test would be
    # true for every call that omits `embiggen`, making the branches below
    # unreachable. Treat both None and False as "embiggen not requested".
    if embiggen is not None and embiggen is not False:
        return self._make_embiggen()

    if inpainting_model_in_use:
        return self._make_omnibus()

    if (init_image is not None) and (mask_image is not None):
        return self._make_inpaint()

    if init_image is not None:
        return self._make_img2img()

    return self._make_txt2img()
|
||||
|
||||
def _make_images(
|
||||
self,
|
||||
@ -657,6 +685,7 @@ class Generate:
|
||||
height,
|
||||
fit=False,
|
||||
text_mask=None,
|
||||
invert_mask=False,
|
||||
):
|
||||
init_image = None
|
||||
init_mask = None
|
||||
@ -686,8 +715,12 @@ class Generate:
|
||||
elif text_mask:
|
||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||
|
||||
if invert_mask:
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
|
||||
return init_image,init_mask
|
||||
|
||||
# lots o' repeated code here! Turn into a make_func()
|
||||
def _make_base(self):
|
||||
if not self.generators.get('base'):
|
||||
from ldm.invoke.generator import Generator
|
||||
@ -698,6 +731,7 @@ class Generate:
|
||||
if not self.generators.get('img2img'):
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
||||
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['img2img']
|
||||
|
||||
def _make_embiggen(self):
|
||||
@ -726,6 +760,15 @@ class Generate:
|
||||
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
||||
return self.generators['inpaint']
|
||||
|
||||
# "omnibus" supports the runwayML custom inpainting model, which does
|
||||
# txt2img, img2img and inpainting using slight variations on the same code
|
||||
def _make_omnibus(self):
    '''
    Return the generator cached under 'omnibus' in self.generators,
    building it on first use.

    On a cache miss, Omnibus is imported lazily, constructed from
    self.model and self.precision, stored in the cache, and given this
    object's free_gpu_mem setting.
    '''
    cached = self.generators.get('omnibus')
    if cached:
        return cached

    # Deferred import: only needed when the generator is first built.
    from ldm.invoke.generator.omnibus import Omnibus
    omnibus = Omnibus(self.model, self.precision)
    self.generators['omnibus'] = omnibus
    omnibus.free_gpu_mem = self.free_gpu_mem
    return self.generators['omnibus']
|
||||
|
||||
def load_model(self):
|
||||
'''
|
||||
preload model identified in self.model_name
|
||||
@ -852,6 +895,8 @@ class Generate:
|
||||
def sample_to_image(self, samples):
    '''
    Pass `samples` to sample_to_image() on the generator returned by
    self._make_base() and return whatever it returns.
    '''
    base_generator = self._make_base()
    return base_generator.sample_to_image(samples)
|
||||
|
||||
# very repetitive code - can this be simplified? The KSampler names are
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
@ -859,15 +904,11 @@ class Generate:
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
|
Reference in New Issue
Block a user