Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Bring main back into a consistent state with other branches
- Due to misuse of rebase command, main was transiently in an inconsistent state.
- This repairs the damage, and adds a few post-release patches that ensure stable conda installs on Mac and Windows.
ldm/generate.py (365 lines changed)
@@ -1,5 +1,5 @@
 # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
-
+import pyparsing
 # Derived from source code carrying the following copyrights
 # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
 # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
@@ -14,6 +14,7 @@ import sys
 import traceback
 import transformers
 import io
+import gc
 import hashlib
 import cv2
 import skimage
@@ -24,6 +25,7 @@ from PIL import Image, ImageOps
 from torch import nn
 from pytorch_lightning import seed_everything, logging

+from ldm.invoke.prompt_parser import PromptParser
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -32,9 +34,11 @@ from ldm.invoke.pngwriter import PngWriter
 from ldm.invoke.args import metadata_from_png
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
-from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.model_cache import ModelCache
-
+from ldm.invoke.seamless import configure_model_padding
+from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
+
 def fix_func(orig):
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
         def new_func(*args, **kw):
@@ -53,23 +57,8 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
-
-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
+# this is fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.5'

 """Simplified text to image API for stable diffusion/latent diffusion

@@ -123,12 +112,13 @@ still work.
 The full list of arguments to Generate() are:
 gr = Generate(
           # these values are set once and shouldn't be changed
-          conf           = path to configuration file ('configs/models.yaml')
-          model          = symbolic name of the model in the configuration file
-          precision      = float precision to be used
+          conf:str            = path to configuration file ('configs/models.yaml')
+          model:str           = symbolic name of the model in the configuration file
+          precision:float     = float precision to be used
+          safety_checker:bool = activate safety checker [False]

           # this value is sticky and maintained between generation calls
-          sampler_name   = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+          sampler_name:str    = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms

           # these are deprecated - use conf and model instead
           weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
@@ -145,23 +135,24 @@ class Generate:

     def __init__(
             self,
-            model          = 'stable-diffusion-1.4',
-            conf           = 'configs/models.yaml',
-            embedding_path = None,
-            sampler_name   = 'k_lms',
-            ddim_eta       = 0.0,  # deterministic
-            full_precision = False,
-            precision      = 'auto',
-            # these are deprecated; if present they override values in the conf file
-            weights        = None,
-            config         = None,
+            model                 = None,
+            conf                  = 'configs/models.yaml',
+            embedding_path        = None,
+            sampler_name          = 'k_lms',
+            ddim_eta              = 0.0,  # deterministic
+            full_precision        = False,
+            precision             = 'auto',
             gfpgan = None,
             codeformer = None,
             esrgan = None,
             free_gpu_mem = False,
+            safety_checker:bool = False,
+            max_loaded_models:int = 2,
+            # these are deprecated; if present they override values in the conf file
+            weights = None,
+            config = None,
     ):
         mconfig = OmegaConf.load(conf)
-        self.model_name = model
         self.height = None
         self.width = None
         self.model_cache = None
@@ -173,6 +164,7 @@ class Generate:
         self.precision = precision
         self.strength = 0.75
         self.seamless = False
+        self.seamless_axes = {'x','y'}
         self.hires_fix = False
         self.embedding_path = embedding_path
         self.model = None  # empty for now
@@ -187,7 +179,11 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.max_loaded_models = max_loaded_models,
         self.size_matters = True  # used to warn once about large image sizes and VRAM
+        self.txt2mask = None
+        self.safety_checker = None
+        self.karras_max = None

         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -205,7 +201,8 @@ class Generate:
             self.precision = choose_precision(self.device)

         # model caching system for fast switching
-        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
+        self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME

         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@@ -214,6 +211,20 @@ class Generate:
         # gets rid of annoying messages about random seed
         logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

+        # load safety checker if requested
+        if safety_checker:
+            try:
+                print('>> Initializing safety checker')
+                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+                from transformers import AutoFeatureExtractor
+                safety_model_id = "CompVis/stable-diffusion-safety-checker"
+                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_checker.to(self.device)
+            except Exception:
+                print('** An error was encountered while installing the safety checker:')
+                print(traceback.format_exc())
+
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -258,14 +269,18 @@ class Generate:
             height = None,
             sampler_name = None,
             seamless = False,
+            seamless_axes = {'x','y'},
             log_tokenization = False,
             with_variations = None,
             variation_amount = 0.0,
             threshold = 0.0,
             perlin = 0.0,
+            karras_max = None,
             # these are specific to img2img and inpaint
             init_img = None,
             init_mask = None,
+            text_mask = None,
+            invert_mask = False,
             fit = False,
             strength = None,
             init_color = None,
@@ -280,9 +295,19 @@ class Generate:
             upscale = None,
             # this is specific to inpainting and causes more extreme inpainting
             inpaint_replace = 0.0,
+            # This will help match inpainted areas to the original image more smoothly
+            mask_blur_radius: int = 8,
             # Set this True to handle KeyboardInterrupt internally
             catch_interrupts = False,
             hires_fix = False,
+            use_mps_noise = False,
+            # Seam settings for outpainting
+            seam_size: int = 0,
+            seam_blur: int = 0,
+            seam_strength: float = 0.7,
+            seam_steps: int = 10,
+            tile_size: int = 32,
+            force_outpaint: bool = False,
             **args,
     ):  # eat up additional cruft
         """
@@ -298,6 +323,9 @@ class Generate:
            seamless                        // whether the generated image should tile
+           hires_fix                       // whether the Hires Fix should be applied during generation
            init_img                        // path to an initial image
            init_mask                       // path to a mask for the initial image
+           text_mask                       // a text string that will be used to guide clipseg generation of the init_mask
+           invert_mask                     // boolean, if true invert the mask
            strength                        // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
            facetool_strength               // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
            ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)
@@ -330,6 +358,7 @@ class Generate:
         width = width or self.width
         height = height or self.height
         seamless = seamless or self.seamless
+        seamless_axes = seamless_axes or self.seamless_axes
         hires_fix = hires_fix or self.hires_fix
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
@@ -337,7 +366,8 @@ class Generate:
         strength = strength or self.strength
         self.seed = seed
         self.log_tokenization = log_tokenization
-        self.step_callback = step_callback
+        self.step_callback = step_callback
+        self.karras_max = karras_max
         with_variations = [] if with_variations is None else with_variations

         # will instantiate the model or return it from cache
@@ -347,10 +377,8 @@ class Generate:
         # to the width and height of the image training set
         width = width or self.width
         height = height or self.height

-        for m in model.modules():
-            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                m.padding_mode = 'circular' if seamless else m._orig_padding_mode
-
+        configure_model_padding(model, seamless, seamless_axes)
+
         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -384,6 +412,11 @@ class Generate:
         self.sampler_name = sampler_name
         self._set_sampler()

+        # bit of a hack to change the cached sampler's karras threshold to
+        # whatever the user asked for
+        if karras_max is not None and isinstance(self.sampler,KSampler):
+            self.sampler.adjust_settings(karras_max=karras_max)
+
         tic = time.time()
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
@@ -393,35 +426,35 @@ class Generate:
         mask_image = None

         try:
-            uc, c = get_uc_and_c(
+            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
                 prompt, model =self.model,
                 skip_normalize=skip_normalize,
                 log_tokens    =self.log_tokenization
             )

-            init_image,mask_image = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
                 height,
                 fit=fit,
+                text_mask=text_mask,
+                invert_mask=invert_mask,
+                force_outpaint=force_outpaint,
             )

             # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            if (init_image is not None) and (mask_image is not None):
-                generator = self._make_inpaint()
-            elif (embiggen != None or embiggen_tiles != None):
-                generator = self._make_embiggen()
-            elif init_image is not None:
-                generator = self._make_img2img()
-            elif hires_fix:
-                generator = self._make_txt2img2img()
-            else:
-                generator = self._make_txt2img()
+            generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)

             generator.set_variation(
                 self.seed, variation_amount, with_variations
             )
+            generator.use_mps_noise = use_mps_noise
+
+            checker = {
+                'checker':self.safety_checker,
+                'extractor':self.safety_feature_extractor
+            } if self.safety_checker else None

             results = generator.generate(
                 prompt,
@@ -430,13 +463,13 @@ class Generate:
                 sampler=self.sampler,
                 steps=steps,
                 cfg_scale=cfg_scale,
-                conditioning=(uc, c),
+                conditioning=(uc, c, extra_conditioning_info),
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,  # called after the final image is generated
-                step_callback=step_callback,  # called after each intermediate image is generated
+                step_callback=step_callback,    # called after each intermediate image is generated
                 width=width,
                 height=height,
-                init_img=init_img,  # embiggen needs to manipulate from the unmodified init_img
+                init_img=init_img,          # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image,      # notice that init_image is different from init_img
                 mask_image=mask_image,
                 strength=strength,
@@ -445,6 +478,14 @@ class Generate:
                 embiggen=embiggen,
                 embiggen_tiles=embiggen_tiles,
                 inpaint_replace=inpaint_replace,
+                mask_blur_radius=mask_blur_radius,
+                safety_checker=checker,
+                seam_size = seam_size,
+                seam_blur = seam_blur,
+                seam_strength = seam_strength,
+                seam_steps = seam_steps,
+                tile_size = tile_size,
+                force_outpaint = force_outpaint
             )

             if init_color:
@@ -461,14 +502,14 @@ class Generate:
                     save_original = save_original,
                     image_callback = image_callback)

-        except RuntimeError as e:
-            print(traceback.format_exc(), file=sys.stderr)
-            print('>> Could not generate image.')
         except KeyboardInterrupt:
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
+        except RuntimeError as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')

         toc = time.time()
         print('>> Usage stats:')
@@ -525,7 +566,7 @@ class Generate:
         # try to reuse the same filename prefix as the original file.
         # we take everything up to the first period
         prefix = None
-        m = re.match('^([^.]+)\.',os.path.basename(image_path))
+        m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
        if m:
            prefix = m.groups()[0]

@@ -533,7 +574,8 @@ class Generate:
         image = Image.open(image_path)

         # used by multiple postfixers
-        uc, c = get_uc_and_c(
+        # todo: cross-attention control
+        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
             prompt, model =self.model,
             skip_normalize=opt.skip_normalize,
             log_tokens    =opt.log_tokenization
@@ -562,30 +604,32 @@ class Generate:
             from ldm.invoke.restoration.outcrop import Outcrop
             extend_instructions = {}
             for direction,pixels in _pairwise(opt.outcrop):
-                extend_instructions[direction]=int(pixels)
-
-            restorer = Outcrop(image,self,)
-            return restorer.process (
-                extend_instructions,
-                opt            = opt,
-                orig_opt       = args,
-                image_callback = callback,
-                prefix         = prefix,
-            )
+                try:
+                    extend_instructions[direction]=int(pixels)
+                except ValueError:
+                    print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
+            if len(extend_instructions)>0:
+                restorer = Outcrop(image,self,)
+                return restorer.process (
+                    extend_instructions,
+                    opt            = opt,
+                    orig_opt       = args,
+                    image_callback = callback,
+                    prefix         = prefix,
+                )

         elif tool == 'embiggen':
             # fetch the metadata from the image
-            generator = self._make_embiggen()
+            generator = self.select_generator(embiggen=True)
             opt.strength = 0.40
             print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
             # embiggen takes a image path (sigh)
             generator.generate(
                 prompt,
                 sampler = self.sampler,
                 steps = opt.steps,
                 cfg_scale = opt.cfg_scale,
                 ddim_eta = self.ddim_eta,
-                conditioning= (uc, c),
+                conditioning= (uc, c, extra_conditioning_info),
                 init_img = image_path,  # not the Image! (sigh)
                 init_image = image,     # embiggen wants both! (sigh)
                 strength = opt.strength,
@@ -612,6 +656,32 @@ class Generate:
             print(f'* postprocessing tool {tool} is not yet supported')
             return None

+    def select_generator(
+            self,
+            init_image:Image.Image=None,
+            mask_image:Image.Image=None,
+            embiggen:bool=False,
+            hires_fix:bool=False,
+            force_outpaint:bool=False,
+    ):
+        inpainting_model_in_use = self.sampler.uses_inpainting_model()
+
+        if hires_fix:
+            return self._make_txt2img2img()
+
+        if embiggen is not None:
+            return self._make_embiggen()
+
+        if inpainting_model_in_use:
+            return self._make_omnibus()
+
+        if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
+            return self._make_inpaint()
+
+        if init_image is not None:
+            return self._make_img2img()
+
+        return self._make_txt2img()
+
     def _make_images(
             self,
@@ -620,40 +690,44 @@ class Generate:
             width,
             height,
             fit=False,
+            text_mask=None,
+            invert_mask=False,
+            force_outpaint=False,
     ):
         init_image = None
         init_mask = None
         if not img:
             return None, None

-        image = self._load_img(
-            img,
-            width,
-            height,
-        )
+        image = self._load_img(img)

+        if image.width < self.width and image.height < self.height:
+            print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
+
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
-            self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
+            self._transparency_check_and_warning(image, mask, force_outpaint)
             init_mask = self._create_init_mask(image, width, height, fit=fit)

         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False

-        init_image = self._create_init_image(image,width,height,fit=fit)  # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)

         if mask:
-            mask_image = self._load_img(
-                mask, width, height)  # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
-
-        return init_image, init_mask
+        elif text_mask:
+            init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
+
+        if invert_mask:
+            init_mask = ImageOps.invert(init_mask)
+
+        return init_image,init_mask

     # lots o' repeated code here! Turn into a make_func()
     def _make_base(self):
         if not self.generators.get('base'):
             from ldm.invoke.generator import Generator
@@ -664,6 +738,7 @@ class Generate:
         if not self.generators.get('img2img'):
             from ldm.invoke.generator.img2img import Img2Img
             self.generators['img2img'] = Img2Img(self.model, self.precision)
+            self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
         return self.generators['img2img']

     def _make_embiggen(self):
@@ -692,6 +767,15 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']

+    # "omnibus" supports the runwayML custom inpainting model, which does
+    # txt2img, img2img and inpainting using slight variations on the same code
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+            self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name
@@ -706,10 +790,20 @@ class Generate:
         if self.model_name == model_name and self.model is not None:
             return self.model

-        model_data = self.model_cache.get_model(model_name)
-        if model_data is None or len(model_data) == 0:
-            print(f'** Model switch failed **')
-            return self.model
+        # the model cache does the loading and offloading
+        cache = self.model_cache
+        cache.print_vram_usage()
+
+        # have to get rid of all references to model in order
+        # to free it from GPU memory
+        self.model = None
+        self.sampler = None
+        self.generators = {}
+        gc.collect()
+
+        model_data = cache.get_model(model_name)
+        if model_data is None:  # restore previous
+            model_data = cache.get_model(self.model_name)

         self.model = model_data['model']
         self.width = model_data['width']
@@ -721,7 +815,7 @@ class Generate:

         seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path is not None:
-            model.embedding_manager.load(
+            self.model.embedding_manager.load(
                 self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
             )

@@ -798,10 +892,32 @@ class Generate:
             else:
                 r[0] = image

+    def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5):
+        assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
+        basename,_ = os.path.splitext(os.path.basename(image_path))
+        if self.txt2mask is None:
+            self.txt2mask = Txt2Mask(device = self.device, refined=True)
+        segmented = self.txt2mask.segment(image_path,prompt)
+        trans = segmented.to_transparent()
+        inverse = segmented.to_transparent(invert=True)
+        mask = segmented.to_mask(threshold)
+
+        path_filter = re.compile(r'[<>:"/\\|?*]')
+        safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .')
+
+        callback(trans,f'{safe_prompt}.deselected',use_prefix=basename)
+        callback(inverse,f'{safe_prompt}.selected',use_prefix=basename)
+        callback(mask,f'{safe_prompt}.masked',use_prefix=basename)
+
     # to help WebGUI - front end to generator util function
     def sample_to_image(self, samples):
         return self._make_base().sample_to_image(samples)

+    def sample_to_lowres_estimated_image(self, samples):
+        return self._make_base().sample_to_lowres_estimated_image(samples)
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
     def _set_sampler(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
@@ -809,15 +925,11 @@ class Generate:
         elif self.sampler_name == 'ddim':
             self.sampler = DDIMSampler(self.model, device=self.device)
         elif self.sampler_name == 'k_dpm_2_a':
-            self.sampler = KSampler(
-                self.model, 'dpm_2_ancestral', device=self.device
-            )
+            self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
         elif self.sampler_name == 'k_dpm_2':
             self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
         elif self.sampler_name == 'k_euler_a':
-            self.sampler = KSampler(
-                self.model, 'euler_ancestral', device=self.device
-            )
+            self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
         elif self.sampler_name == 'k_euler':
             self.sampler = KSampler(self.model, 'euler', device=self.device)
         elif self.sampler_name == 'k_heun':
@@ -830,7 +942,7 @@ class Generate:

         print(msg)

-    def _load_img(self, img, width, height)->Image:
+    def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img
             print(
@@ -851,47 +963,48 @@ class Generate:
             image = ImageOps.exif_transpose(image)
         return image

-    def _create_init_image(self, image, width, height, fit=True):
-        image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+    def _create_init_image(self, image: Image.Image, width, height, fit=True):
+        if image.mode != 'RGBA':
+            image = image.convert('RGBA')
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
     # image with the transparent part black.
-    def _image_to_mask(self, mask_image, invert=False) -> Image:
-        # Obtain the mask from the transparency channel
-        mask = Image.new(mode="L", size=mask_image.size, color=255)
-        mask.putdata(mask_image.getdata(band=3))
+    def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
+        if mask_image.mode == 'L':
+            mask = mask_image
+        elif mask_image.mode in ('RGB', 'P'):
+            mask = mask_image.convert('L')
+        else:
+            # Obtain the mask from the transparency channel
+            mask = Image.new(mode="L", size=mask_image.size, color=255)
+            mask.putdata(mask_image.getdata(band=3))
         if invert:
             mask = ImageOps.invert(mask)
         return mask

+    def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
+        prompt = text_mask[0]
+        confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
+        if self.txt2mask is None:
+            self.txt2mask = Txt2Mask(device = self.device)
+
+        segmented = self.txt2mask.segment(image, prompt)
+        mask = segmented.to_mask(float(confidence_level))
+        mask = mask.convert('RGB')
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask
+
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
             return True
@@ -919,11 +1032,11 @@ class Generate:
                     colored += 1
         return colored == 0

-    def _transparency_check_and_warning(self,image, mask):
+    def _transparency_check_and_warning(self,image, mask, force_outpaint=False):
         if not mask:
             print(
                 '>> Initial image has transparent areas. Will inpaint in these regions.')
-            if self._check_for_erasure(image):
+            if (not force_outpaint) and self._check_for_erasure(image):
                 print(
                     '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
                     '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
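Taken together, the hunks above change the public surface of Generate: model switching goes through ModelCache (now bounded by max_loaded_models), an optional safety checker can be enabled at construction time, prompts are parsed into (uc, c, extra_conditioning_info) by get_uc_and_c_and_ec, and generation gains text_mask, invert_mask, karras_max, and the seam/outpainting controls. A minimal usage sketch against the post-commit API follows; the method name prompt2image, the model name, the prompt strings, and the file path are illustrative assumptions that are not shown in this diff.

    from ldm.generate import Generate

    # Construct once; per the docstring hunk, these values are set at
    # construction time and shouldn't be changed afterwards.
    gr = Generate(
        model             = 'stable-diffusion-1.5',  # None falls back to the conf default or FALLBACK_MODEL_NAME
        conf              = 'configs/models.yaml',
        sampler_name      = 'k_lms',
        safety_checker    = True,  # loads the CompVis checker shown in the __init__ hunk
        max_loaded_models = 2,     # forwarded to ModelCache for fast model switching
    )

    # txt2img, overriding the cached KSampler's karras threshold per call
    # (prompt2image is assumed to be the method behind prompt2png's **kwargs)
    results = gr.prompt2image(
        prompt     = 'a sea otter floating on its back',  # illustrative prompt
        steps      = 30,
        cfg_scale  = 7.5,
        karras_max = 29,  # illustrative value; any non-None int triggers adjust_settings()
    )

    # inpainting guided by a clipseg text mask instead of a hand-drawn one;
    # per _txt2mask(), text_mask is [prompt, optional confidence in 0..1]
    results = gr.prompt2image(
        prompt      = 'a bouquet of flowers',  # illustrative
        init_img    = 'inputs/vase.png',       # illustrative path
        text_mask   = ['vase', 0.6],
        invert_mask = False,
        strength    = 0.75,
    )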