Bring main back into a consistent state with other branches

- Due to misuse of the rebase command, main was transiently
  in an inconsistent state.

- This repairs the damage and adds a few post-release
  patches that ensure stable conda installs on Mac and Windows.
Lincoln Stein
2022-11-03 15:44:06 -04:00
370 changed files with 32709 additions and 9854 deletions


@@ -1,5 +1,5 @@
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import pyparsing
# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
@@ -14,6 +14,7 @@ import sys
import traceback
import transformers
import io
import gc
import hashlib
import cv2
import skimage
@@ -24,6 +25,7 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
from ldm.invoke.prompt_parser import PromptParser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@@ -32,9 +34,11 @@ from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
@@ -53,23 +57,8 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
# this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.5'
"""Simplified text to image API for stable diffusion/latent diffusion
@@ -123,12 +112,13 @@ still work.
The full list of arguments to Generate() are:
gr = Generate(
# these values are set once and shouldn't be changed
conf = path to configuration file ('configs/models.yaml')
model = symbolic name of the model in the configuration file
precision = float precision to be used
conf:str = path to configuration file ('configs/models.yaml')
model:str = symbolic name of the model in the configuration file
precision:str = float precision to be used
safety_checker:bool = activate safety checker [False]
# this value is sticky and maintained between generation calls
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
# these are deprecated - use conf and model instead
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
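A minimal usage sketch (illustrative only; assumes the default
configs/models.yaml defines the named model):

    gr = Generate(model='stable-diffusion-1.5')
    gr.load_model()
    results = gr.prompt2image(prompt='a sunlit meadow', steps=30)
    image, seed = results[0]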
@@ -145,23 +135,24 @@ class Generate:
def __init__(
self,
model = 'stable-diffusion-1.4',
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
max_loaded_models:int=2,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
):
mconfig = OmegaConf.load(conf)
self.model_name = model
self.height = None
self.width = None
self.model_cache = None
@@ -173,6 +164,7 @@ class Generate:
self.precision = precision
self.strength = 0.75
self.seamless = False
self.seamless_axes = {'x','y'}
self.hires_fix = False
self.embedding_path = embedding_path
self.model = None # empty for now
@@ -187,7 +179,11 @@ class Generate:
self.codeformer = codeformer
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.max_loaded_models = max_loaded_models
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
self.karras_max = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@@ -205,7 +201,8 @@ class Generate:
self.precision = choose_precision(self.device)
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@@ -214,6 +211,20 @@ class Generate:
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load safety checker if requested
if safety_checker:
try:
print('>> Initializing safety checker')
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
self.safety_checker.to(self.device)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
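# (The checker/extractor pair initialized above is packaged into the
# `checker` dict built in prompt2image() below, passed to
# generator.generate(), and applied in Generator.safety_check() -- see the
# generator base class later in this commit.)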
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@@ -258,14 +269,18 @@ class Generate:
height = None,
sampler_name = None,
seamless = False,
seamless_axes = {'x','y'},
log_tokenization = False,
with_variations = None,
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
karras_max = None,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
text_mask = None,
invert_mask = False,
fit = False,
strength = None,
init_color = None,
@@ -280,9 +295,19 @@ class Generate:
upscale = None,
# this is specific to inpainting and causes more extreme inpainting
inpaint_replace = 0.0,
# This will help match inpainted areas to the original image more smoothly
mask_blur_radius: int = 8,
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
# Seam settings for outpainting
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
force_outpaint: bool = False,
**args,
): # eat up additional cruft
"""
@@ -298,6 +323,9 @@ class Generate:
seamless // whether the generated image should tile
hires_fix // whether the Hires Fix should be applied during generation
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
invert_mask // boolean, if true invert the mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
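An illustrative call (assumes an existing Generate instance gr and a
local file photo.png, both hypothetical):

    gr.prompt2image('a ginger cat', init_img='photo.png',
                    text_mask=['the dog', 0.6], strength=0.75)

Here clipseg derives the inpainting mask from the text description
instead of an explicit mask file.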
@@ -330,6 +358,7 @@ class Generate:
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
seamless_axes = seamless_axes or self.seamless_axes
hires_fix = hires_fix or self.hires_fix
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
@@ -337,7 +366,8 @@ class Generate:
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.step_callback = step_callback
self.karras_max = karras_max
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@@ -347,10 +377,8 @@ class Generate:
# to the width and height of the image training set
width = width or self.width
height = height or self.height
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
configure_model_padding(model, seamless, seamless_axes)
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -384,6 +412,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
if karras_max is not None and isinstance(self.sampler,KSampler):
self.sampler.adjust_settings(karras_max=karras_max)
tic = time.time()
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@@ -393,35 +426,35 @@ class Generate:
mask_image = None
try:
uc, c = get_uc_and_c(
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
log_tokens =self.log_tokenization
)
init_image,mask_image = self._make_images(
init_image, mask_image = self._make_images(
init_img,
init_mask,
width,
height,
fit=fit,
text_mask=text_mask,
invert_mask=invert_mask,
force_outpaint=force_outpaint,
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif (embiggen != None or embiggen_tiles != None):
generator = self._make_embiggen()
elif init_image is not None:
generator = self._make_img2img()
elif hires_fix:
generator = self._make_txt2img2img()
else:
generator = self._make_txt2img()
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
generator.set_variation(
self.seed, variation_amount, with_variations
)
generator.use_mps_noise = use_mps_noise
checker = {
'checker':self.safety_checker,
'extractor':self.safety_feature_extractor
} if self.safety_checker else None
results = generator.generate(
prompt,
@@ -430,13 +463,13 @@ class Generate:
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c),
conditioning=(uc, c, extra_conditioning_info),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
@@ -445,6 +478,14 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius,
safety_checker=checker,
seam_size = seam_size,
seam_blur = seam_blur,
seam_strength = seam_strength,
seam_steps = seam_steps,
tile_size = tile_size,
force_outpaint = force_outpaint
)
if init_color:
@@ -461,14 +502,14 @@ class Generate:
save_original = save_original,
image_callback = image_callback)
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
except KeyboardInterrupt:
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
toc = time.time()
print('>> Usage stats:')
@@ -525,7 +566,7 @@ class Generate:
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
prefix = None
m = re.match('^([^.]+)\.',os.path.basename(image_path))
m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
if m:
prefix = m.groups()[0]
@@ -533,7 +574,8 @@ class Generate:
image = Image.open(image_path)
# used by multiple postfixers
uc, c = get_uc_and_c(
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
@@ -562,30 +604,32 @@ class Generate:
from ldm.invoke.restoration.outcrop import Outcrop
extend_instructions = {}
for direction,pixels in _pairwise(opt.outcrop):
extend_instructions[direction]=int(pixels)
restorer = Outcrop(image,self,)
return restorer.process (
extend_instructions,
opt = opt,
orig_opt = args,
image_callback = callback,
prefix = prefix,
)
try:
extend_instructions[direction]=int(pixels)
except ValueError:
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
if len(extend_instructions)>0:
restorer = Outcrop(image,self,)
return restorer.process (
extend_instructions,
opt = opt,
orig_opt = args,
image_callback = callback,
prefix = prefix,
)
elif tool == 'embiggen':
# fetch the metadata from the image
generator = self._make_embiggen()
generator = self.select_generator(embiggen=True)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
# embiggen takes a image path (sigh)
generator.generate(
prompt,
sampler = self.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.ddim_eta,
conditioning= (uc, c),
conditioning= (uc, c, extra_conditioning_info),
init_img = image_path, # not the Image! (sigh)
init_image = image, # embiggen wants both! (sigh)
strength = opt.strength,
@@ -612,6 +656,32 @@ class Generate:
print(f'* postprocessing tool {tool} is not yet supported')
return None
def select_generator(
self,
init_image:Image.Image=None,
mask_image:Image.Image=None,
embiggen=None,
hires_fix:bool=False,
force_outpaint:bool=False,
):
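'''Choose the generator for this request. Precedence, as coded below:
hires_fix, then embiggen, then the omnibus generator when an inpainting
model is loaded, then inpaint (mask present or force_outpaint), then
img2img, then plain txt2img.'''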
inpainting_model_in_use = self.sampler.uses_inpainting_model()
if hires_fix:
return self._make_txt2img2img()
if embiggen is not None:
return self._make_embiggen()
if inpainting_model_in_use:
return self._make_omnibus()
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint()
if init_image is not None:
return self._make_img2img()
return self._make_txt2img()
def _make_images(
self,
@@ -620,40 +690,44 @@ class Generate:
width,
height,
fit=False,
text_mask=None,
invert_mask=False,
force_outpaint=False,
):
init_image = None
init_mask = None
if not img:
return None, None
image = self._load_img(
img,
width,
height,
)
image = self._load_img(img)
if image.width < self.width and image.height < self.height:
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image):
self._transparency_check_and_warning(image, mask)
# this returns a torch tensor
self._transparency_check_and_warning(image, mask, force_outpaint)
init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
init_image = self._create_init_image(image,width,height,fit=fit)
if mask:
mask_image = self._load_img(
mask, width, height) # this returns an Image
mask_image = self._load_img(mask)
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
return init_image, init_mask
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
if invert_mask:
init_mask = ImageOps.invert(init_mask)
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func() (one possible factory is sketched below, after _make_inpaint())
def _make_base(self):
if not self.generators.get('base'):
from ldm.invoke.generator import Generator
@@ -664,6 +738,7 @@ class Generate:
if not self.generators.get('img2img'):
from ldm.invoke.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model, self.precision)
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['img2img']
def _make_embiggen(self):
@@ -692,6 +767,15 @@ class Generate:
self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint']
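# One way to realize the make_func() suggested above -- an illustrative
# sketch, not part of this commit (the helper name is hypothetical):
def _make_generator(self, name, module_path, class_name):
    # cache one generator instance per name, mirroring the methods above
    if not self.generators.get(name):
        import importlib
        cls = getattr(importlib.import_module(module_path), class_name)
        self.generators[name] = cls(self.model, self.precision)
        self.generators[name].free_gpu_mem = self.free_gpu_mem
    return self.generators[name]
# e.g. self._make_generator('img2img', 'ldm.invoke.generator.img2img', 'Img2Img')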
# "omnibus" supports the runwayML custom inpainting model, which does
# txt2img, img2img and inpainting using slight variations on the same code
def _make_omnibus(self):
if not self.generators.get('omnibus'):
from ldm.invoke.generator.omnibus import Omnibus
self.generators['omnibus'] = Omnibus(self.model, self.precision)
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
return self.generators['omnibus']
def load_model(self):
'''
preload model identified in self.model_name
@@ -706,10 +790,20 @@ class Generate:
if self.model_name == model_name and self.model is not None:
return self.model
model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0:
print(f'** Model switch failed **')
return self.model
# the model cache does the loading and offloading
cache = self.model_cache
cache.print_vram_usage()
# have to get rid of all references to model in order
# to free it from GPU memory
self.model = None
self.sampler = None
self.generators = {}
gc.collect()
model_data = cache.get_model(model_name)
if model_data is None: # restore previous
model_data = cache.get_model(self.model_name)
self.model = model_data['model']
self.width = model_data['width']
@@ -721,7 +815,7 @@ class Generate:
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
model.embedding_manager.load(
self.model.embedding_manager.load(
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
@@ -798,10 +892,32 @@ class Generate:
else:
r[0] = image
def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5):
assert os.path.exists(image_path), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **'
basename,_ = os.path.splitext(os.path.basename(image_path))
if self.txt2mask is None:
self.txt2mask = Txt2Mask(device = self.device, refined=True)
segmented = self.txt2mask.segment(image_path,prompt)
trans = segmented.to_transparent()
inverse = segmented.to_transparent(invert=True)
mask = segmented.to_mask(threshold)
path_filter = re.compile(r'[<>:"/\\|?*]')
safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .')
callback(trans,f'{safe_prompt}.deselected',use_prefix=basename)
callback(inverse,f'{safe_prompt}.selected',use_prefix=basename)
callback(mask,f'{safe_prompt}.masked',use_prefix=basename)
# to help WebGUI - front end to generator util function
def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples)
def sample_to_lowres_estimated_image(self, samples):
return self._make_base().sample_to_lowres_estimated_image(samples)
# very repetitive code - can this be simplified? The KSampler names are
# consistent, at least (a table-driven sketch follows _set_sampler() below)
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms':
@@ -809,15 +925,11 @@ class Generate:
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun':
@@ -830,7 +942,7 @@ class Generate:
print(msg)
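# A table-driven alternative to the chain above (an illustrative sketch,
# not part of this commit) -- the k_* names map 1:1 onto KSampler
# schedule names:
K_SAMPLER_SCHEDULES = {
    'k_dpm_2_a': 'dpm_2_ancestral',
    'k_dpm_2': 'dpm_2',
    'k_euler_a': 'euler_ancestral',
    'k_euler': 'euler',
    'k_heun': 'heun',
    'k_lms': 'lms',
}
# if self.sampler_name in K_SAMPLER_SCHEDULES:
#     self.sampler = KSampler(self.model,
#                             K_SAMPLER_SCHEDULES[self.sampler_name],
#                             device=self.device)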
def _load_img(self, img, width, height)->Image:
def _load_img(self, img)->Image:
if isinstance(img, Image.Image):
image = img
print(
@@ -851,47 +963,48 @@ class Generate:
image = ImageOps.exif_transpose(image)
return image
def _create_init_image(self, image, width, height, fit=True):
image = image.convert('RGB')
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.0 * image - 1.0
return image.to(self.device)
def _create_init_image(self, image: Image.Image, width, height, fit=True):
if image.mode != 'RGBA':
image = image.convert('RGBA')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image
def _create_init_mask(self, image, width, height, fit=True):
# convert into a black/white mask
image = self._image_to_mask(image)
image = image.convert('RGB')
# now we adjust the size
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.NEAREST)
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image.to(self.device)
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image
# The mask is expected to have the region to be inpainted
# with alpha transparency. It converts it into a black/white
# image with the transparent part black.
def _image_to_mask(self, mask_image, invert=False) -> Image:
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
# Obtain the mask from the transparency channel
mask = Image.new(mode="L", size=mask_image.size, color=255)
mask.putdata(mask_image.getdata(band=3))
if mask_image.mode == 'L':
mask = mask_image
elif mask_image.mode in ('RGB', 'P'):
mask = mask_image.convert('L')
else:
# Obtain the mask from the transparency channel
mask = Image.new(mode="L", size=mask_image.size, color=255)
mask.putdata(mask_image.getdata(band=3))
if invert:
mask = ImageOps.invert(mask)
return mask
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
prompt = text_mask[0]
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
if self.txt2mask is None:
self.txt2mask = Txt2Mask(device = self.device)
segmented = self.txt2mask.segment(image, prompt)
mask = segmented.to_mask(float(confidence_level))
mask = mask.convert('RGB')
mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
return mask
def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:
return True
@@ -919,11 +1032,11 @@ class Generate:
colored += 1
return colored == 0
def _transparency_check_and_warning(self,image, mask):
def _transparency_check_and_warning(self,image, mask, force_outpaint=False):
if not mask:
print(
'>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image):
if (not force_outpaint) and self._check_for_erasure(image):
print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',


@@ -83,16 +83,16 @@ with metadata_from_png():
import argparse
from argparse import Namespace, RawTextHelpFormatter
import pydoc
import shlex
import json
import hashlib
import os
import re
import shlex
import copy
import base64
import functools
import ldm.invoke.pngwriter
from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
SAMPLER_CHOICES = [
'ddim',
@@ -113,8 +113,8 @@ PRECISION_CHOICES = [
]
# is there a way to pick this up during git commits?
APP_ID = 'lstein/stable-diffusion'
APP_VERSION = 'v1.15'
APP_ID = 'invoke-ai/InvokeAI'
APP_VERSION = 'v2.02'
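# One answer to the question above (an illustrative sketch, not part of
# this commit): derive the version from git at runtime, falling back to
# the constant when not running inside a checkout.
import subprocess
def _app_version(fallback=APP_VERSION):
    try:
        return subprocess.check_output(
            ['git', 'describe', '--tags'], text=True).strip()
    except (OSError, subprocess.CalledProcessError):
        return fallback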
class ArgFormatter(argparse.RawTextHelpFormatter):
# use defined argument order to display usage
@@ -169,27 +169,31 @@ class Args(object):
def parse_cmd(self,cmd_string):
'''Parse a invoke>-style command string '''
command = cmd_string.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError:
import sys, traceback
print(traceback.format_exc(), file=sys.stderr)
return
switches = ['']
switches_started = False
for element in elements:
if element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)
# handle the case in which the first token is a switch
if cmd_string.startswith('-'):
prompt = ''
switches = cmd_string
# handle the case in which the prompt is enclosed by quotes
elif cmd_string.startswith('"'):
a = shlex.split(cmd_string,comments=True)
prompt = a[0]
switches = shlex.join(a[1:])
else:
# no initial quote, so get everything up to the first thing
# that looks like a switch
if cmd_string.startswith('-'):
prompt = ''
switches = cmd_string
else:
switches[0] += element
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
match = re.match(r'^(.+?)\s(--?[a-zA-Z].+)',cmd_string)
if match:
prompt,switches = match.groups()
else:
prompt = cmd_string
switches = ''
try:
self._cmd_switches = self._cmd_parser.parse_args(switches)
self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches,comments=True))
setattr(self._cmd_switches,'prompt',prompt)
return self._cmd_switches
except:
return None
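# Examples of the cases handled above (illustrative):
# parse_cmd('-I foo.png -f 0.6') -> prompt '', everything is switches
# parse_cmd('"a cat" -s 20') -> shlex peels off the quoted prompt first
# parse_cmd('a cat on a mat -s 20') -> the regex splits at the first switch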
@@ -210,12 +214,16 @@ class Args(object):
a = vars(self)
a.update(kwargs)
switches = list()
switches.append(f'"{a["prompt"]}"')
prompt = a['prompt']
prompt = prompt.replace('"','\\"')
switches.append(prompt)
switches.append(f'-s {a["steps"]}')
switches.append(f'-S {a["seed"]}')
switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}')
if a['karras_max'] is not None:
switches.append(f'--karras_max {a["karras_max"]}')
if a['perlin'] > 0:
switches.append(f'--perlin {a["perlin"]}')
if a['threshold'] > 0:
@@ -241,6 +249,8 @@ class Args(object):
switches.append(f'-f {a["strength"]}')
if a['inpaint_replace']:
switches.append(f'--inpaint_replace')
if a['text_mask']:
switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}')
else:
switches.append(f'-A {a["sampler_name"]}')
@@ -366,17 +376,16 @@ class Args(object):
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument(
'--conf',
'--config',
'-c',
'-conf',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--png_compression','-z',
@@ -404,6 +413,13 @@ class Args(object):
action='store_true',
help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
model_group.add_argument(
'--free_gpu_mem',
dest='free_gpu_mem',
@@ -419,6 +435,11 @@ class Args(object):
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--safety_checker',
action='store_true',
help='Check for and blur potentially NSFW images',
)
file_group.add_argument(
'--from_file',
dest='infile',
@@ -438,6 +459,12 @@ class Args(object):
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
)
render_group.add_argument(
'--grid',
'-g',
@@ -529,7 +556,7 @@ class Args(object):
formatter_class=ArgFormatter,
description=
"""
*Image generation:*
*Image generation*
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
*postprocessing*
@@ -544,14 +571,28 @@ class Args(object):
!history lists all the commands issued during the current session.
!NN retrieves the NNth command from the history
*Model manipulation*
!models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name>
!import_model path/to/weights/file.ckpt -- adds a model to your config
!edit_model <model_name> -- edit a model's description
!del_model <model_name> -- delete a model
"""
)
render_group = parser.add_argument_group('General rendering')
img2img_group = parser.add_argument_group('Image-to-image and inpainting')
inpainting_group = parser.add_argument_group('Inpainting')
outpainting_group = parser.add_argument_group('Outpainting and outcropping')
variation_group = parser.add_argument_group('Creating and combining variations')
postprocessing_group = parser.add_argument_group('Post-processing')
special_effects_group = parser.add_argument_group('Special effects')
render_group.add_argument('prompt')
deprecated_group = parser.add_argument_group('Deprecated options')
render_group.add_argument(
'--prompt',
default='',
help='prompt string',
)
render_group.add_argument(
'-s',
'--steps',
@@ -604,6 +645,12 @@ class Args(object):
type=float,
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
)
render_group.add_argument(
'--grid',
'-g',
@@ -663,7 +710,13 @@ class Args(object):
default=6,
choices=range(0,10),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
help='level of PNG compression, from 0 (none) to 9 (maximum). [6]'
)
render_group.add_argument(
'--karras_max',
type=int,
default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
)
img2img_group.add_argument(
'-I',
@@ -672,10 +725,12 @@ class Args(object):
help='Path to input image for img2img mode (supersedes width and height)',
)
img2img_group.add_argument(
'-M',
'--init_mask',
'-tm',
'--text_mask',
nargs='+',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
default=None,
)
img2img_group.add_argument(
'--init_color',
@@ -696,29 +751,68 @@ class Args(object):
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
default=0.75,
)
img2img_group.add_argument(
'-D',
'--out_direction',
nargs='+',
inpainting_group.add_argument(
'-M',
'--init_mask',
type=str,
metavar=('direction', 'pixels'),
help='Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'-c',
'--outcrop',
nargs='+',
type=str,
metavar=('direction','pixels'),
help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64',
inpainting_group.add_argument(
'--invert_mask',
action='store_true',
help='Invert the mask',
)
img2img_group.add_argument(
inpainting_group.add_argument(
'-r',
'--inpaint_replace',
type=float,
default=0.0,
help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)',
)
outpainting_group.add_argument(
'-c',
'--outcrop',
nargs='+',
type=str,
metavar=('direction','pixels'),
help='Outcrop the image with one or more direction/pixel pairs: e.g. -c top 64 bottom 128 left 64 right 64',
)
outpainting_group.add_argument(
'--force_outpaint',
action='store_true',
default=False,
help='Force outpainting if you have no inpainting mask to pass',
)
outpainting_group.add_argument(
'--seam_size',
type=int,
default=0,
help='When outpainting, size of the mask around the seam between original and outpainted image',
)
outpainting_group.add_argument(
'--seam_blur',
type=int,
default=0,
help='When outpainting, the amount to blur the seam inwards',
)
outpainting_group.add_argument(
'--seam_strength',
type=float,
default=0.7,
help='When outpainting, the img2img strength to use when filling the seam. Values around 0.7 work well',
)
outpainting_group.add_argument(
'--seam_steps',
type=int,
default=10,
help='When outpainting, the number of steps to use to fill the seam. Low values (~10) work well',
)
outpainting_group.add_argument(
'--tile_size',
type=int,
default=32,
help='When outpainting, the tile size to use for filling outpaint areas',
)
postprocessing_group.add_argument(
'-ft',
'--facetool',
@@ -776,6 +870,12 @@ class Args(object):
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
special_effects_group.add_argument(
'--seamless_axes',
default=['x', 'y'],
type=list[str],
help='Specify which axes to use circular convolution on.',
)
variation_group.add_argument(
'-v',
'--variation_amount',
@@ -790,6 +890,20 @@ class Args(object):
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...`'
)
render_group.add_argument(
'--use_mps_noise',
action='store_true',
dest='use_mps_noise',
help='Simulate noise on M1 systems to get the same results'
)
deprecated_group.add_argument(
'-D',
'--out_direction',
nargs='+',
type=str,
metavar=('direction', 'pixels'),
help='Older outcropping system. Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
)
return parser
def format_metadata(**kwargs):
@@ -825,9 +939,8 @@ def metadata_dumps(opt,
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
'init_img','init_mask']
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless',
'init_img','init_mask','facetool','facetool_strength','upscale']
rfc_dict ={}
for item in image_dict.items():
@@ -877,6 +990,23 @@ def metadata_dumps(opt,
return metadata
@functools.lru_cache(maxsize=50)
def args_from_png(png_file_path) -> list[Args]:
'''
Given the path to a PNG file created by invoke.py,
retrieves a list of Args objects containing the image
metadata.
'''
try:
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
except AttributeError:
return [legacy_metadata_load({},png_file_path)]
try:
return metadata_loads(meta)
except:
return [legacy_metadata_load(meta,png_file_path)]
@functools.lru_cache(maxsize=50)
def metadata_from_png(png_file_path) -> Args:
'''
@@ -884,11 +1014,8 @@ def metadata_from_png(png_file_path) -> Args:
an Args object containing the image metadata. Note that this
returns a single Args object, not multiple.
'''
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 :
return metadata_loads(meta)[0]
else:
return legacy_metadata_load(meta,png_file_path)
args_list = args_from_png(png_file_path)
return args_list[0]
def dream_cmd_from_png(png_file_path):
opt = metadata_from_png(png_file_path)
@@ -903,14 +1030,14 @@ def metadata_loads(metadata) -> list:
'''
results = []
try:
if 'grid' in metadata['sd-metadata']:
if 'images' in metadata['sd-metadata']:
images = metadata['sd-metadata']['images']
else:
images = [metadata['sd-metadata']['image']]
for image in images:
# repack the prompt and variations
if 'prompt' in image:
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
image['prompt'] = repack_prompt(image['prompt'])
if 'variations' in image:
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
# fix a bit of semantic drift here
@@ -918,12 +1045,19 @@ def metadata_loads(metadata) -> list:
opt = Args()
opt._cmd_switches = Namespace(**image)
results.append(opt)
except KeyError as e:
except Exception as e:
import sys, traceback
print('>> badly-formatted metadata',file=sys.stderr)
print('>> could not read metadata',file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return results
def repack_prompt(prompt_list:list)->str:
# in the common case of no weighting syntax, just return the prompt as is
if len(prompt_list) > 1:
return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
else:
return prompt_list[0]['prompt']
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):
@@ -953,17 +1087,17 @@ def sha256(path):
return sha.hexdigest()
def legacy_metadata_load(meta,pathname) -> Args:
opt = Args()
if 'Dream' in meta and len(meta['Dream']) > 0:
dream_prompt = meta['Dream']
opt = Args()
opt.parse_cmd(dream_prompt)
return opt
else: # if nothing else, we can get the seed
match = re.search(r'\d+\.(\d+)',pathname)
if match:
seed = match.groups()[0]
opt = Args()
opt.seed = seed
return opt
return None
else:
opt.prompt = ''
opt.seed = 0
return opt


@@ -1,110 +1,195 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.
This module handles the generation of the conditioning tensors.
Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
'''
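# Typical use, as wired into generate.py earlier in this commit:
# uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
#     prompt, model=self.model, log_tokens=self.log_tokenization)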
import re
from difflib import SequenceMatcher
from typing import Union
import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt)
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
uc = model.get_learned_conditioning([unconditioned_words])
pp = PromptParser()
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
prompt, skip_normalize
parsed_prompt: Union[FlattenedPrompt, Blend] = None
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
# we don't support conjunctions for now
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for i,flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_opcodes = []
edit_options = []
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
)
)
if len(weighted_subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens, weight)
c = torch.add(
c,
model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words])
return (uc, c)
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
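# Example (illustrative): split_weighted_subprompts('a cat:1 a dog:3')
# yields [('a cat', 0.25), ('a dog', 0.75)] -- the weights 1 and 3 are
# normalized to sum to 1 unless skip_normalize is set.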
def build_token_edit_opcodes(original_tokens, edited_tokens):
original_tokens = original_tokens.cpu().numpy()[0]
edited_tokens = edited_tokens.cpu().numpy()[0]
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if log_tokens:
text = " ".join(fragments)
log_tokenization(text, model, display_label=log_display_label)
return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid'
'''
assert isinstance(uncond, dict)
assert isinstance(cond, dict)
cond_flattened = dict()
for k in cond:
if isinstance(cond[k], list):
cond_flattened[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened


@@ -6,26 +6,31 @@ import torch
import numpy as np
import random
import os
import traceback
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'assets/caution.png'
class Generator():
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@@ -40,12 +45,15 @@ class Generator():
self.variation_amount = variation_amount
self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
sampler = sampler,
init_image = init_image,
width = width,
height = height,
@@ -54,12 +62,14 @@ class Generator():
perlin = perlin,
**kwargs
)
results = []
seed = seed if seed is not None else self.new_seed()
first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(self.model.device.type), self.model.ema_scope():
# There used to be an additional self.model.ema_scope() here, but it breaks
# the inpaint-1.5 model. Not sure what it did.... ?
with scope(self.model.device.type):
for n in trange(iterations, desc='Generating'):
x_T = None
if self.variation_amount > 0:
@@ -74,19 +84,27 @@ class Generator():
try:
x_T = self.get_noise(width,height)
except:
pass
print('** An error occurred while getting initial noise **')
print(traceback.format_exc())
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results
def sample_to_image(self,samples):
def sample_to_image(self,samples)->Image.Image:
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
Given samples returned from a sampler, converts
them into a PIL Image
"""
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -98,6 +116,29 @@ class Generator():
)
return Image.fromarray(x_sample.astype(np.uint8))
# write an approximate RGB image from latent samples for a single step to PNG
def sample_to_lowres_estimated_image(self,samples):
# adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these numbers were determined empirically by @keturn
v1_4_latent_rgb_factors = torch.tensor([
# R G B
[ 0.298, 0.207, 0.208], # L1
[ 0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], dtype=samples.dtype, device=samples.device)
latent_image = samples[0].permute(1, 2, 0) @ v1_4_latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
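# In other words: `samples` is (B, 4, H/8, W/8) for SD v1 models; the
# permute gives (H/8, W/8, 4) for the first sample, and the 4x3 matmul
# maps the four latent channels to approximate RGB, yielding a 1/8-scale
# preview image.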
def generate_initial_noise(self, seed, width, height):
initial_noise = None
if self.variation_amount > 0 or len(self.with_variations) > 0:
@@ -169,6 +210,40 @@ class Generator():
return v2
def safety_check(self,image:Image.Image):
'''
If the CompViz safety checker flags an NSFW image, we
blur it out.
'''
import diffusers
checker = self.safety_checker['checker']
extractor = self.safety_checker['extractor']
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
if has_nsfw_concept[0]:
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
return self.blur(image)
else:
return image
def blur(self,input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = Image.open(CAUTION_IMG)
caution = caution.resize((caution.width // 2, caution.height //2))
blurry.paste(caution,(0,0),caution)
except FileNotFoundError:
pass
return blurry
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):


@@ -21,6 +21,7 @@ class Embiggen(Generator):
def generate(self,prompt,iterations=1,seed=None,
image_callback=None, step_callback=None,
**kwargs):
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
prompt,
@@ -63,6 +64,8 @@ class Embiggen(Generator):
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling


@@ -4,14 +4,18 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
import torch
import numpy as np
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
import PIL
from torch import Tensor
from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@ -25,6 +29,9 @@ class Img2Img(Generator):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image.convert('RGB'))
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
@ -32,7 +39,7 @@ class Img2Img(Generator):
) # move to latent space
t_enc = int(strength * steps)
uc, c, extra_conditioning_info = conditioning
def make_image(x_T):
# encode (scaled latent)
@ -49,7 +56,9 @@ class Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent,   # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info,
all_timesteps_count = steps
)
return self.sample_to_image(samples)
@ -68,3 +77,14 @@ class Img2Img(Generator):
shape = init_latent.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
def _image_to_tensor(self, image:Image.Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0
return image.to(self.model.device)

View File

@ -2,28 +2,188 @@
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
import math
import torch
import torchvision.transforms as T
import numpy as np
import cv2 as cv
import PIL
from PIL import Image, ImageFilter, ImageOps
from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.invoke.generator.base import downsampling
class Inpaint(Img2Img):
def __init__(self, model, precision):
self.init_latent = None
self.pil_image = None
self.pil_mask = None
self.mask_blur_radius = 0
super().__init__(model, precision)
# Outpaint support code
def get_tile_images(self, image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False
)
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != 'RGBA':
return im
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a,*tile_size).copy()
# Get the mask as tiles
tiles_mask = tiles[:,:,:,:,3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = (tiles_mask > 0)
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = np.logical_not(tiles_mask).sum()
rng = np.random.default_rng(seed = seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1,2)
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
si = Image.fromarray(st, mode='RGBA')
return si
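# A small sketch of the stride trick used by get_tile_images() above, with
# assumed sizes: a 64x64 RGBA array viewed as 16x16 tiles yields a
# (4, 4, 16, 16, 4) array without copying, which is what tile_fill_missing()
# samples replacement tiles from.
import numpy as np

a = np.zeros((64, 64, 4), dtype=np.uint8)
tiles = np.lib.stride_tricks.as_strided(
    np.ravel(a),
    shape=(4, 4, 16, 16, 4),    # tile rows, tile cols, tile h, tile w, channels
    strides=(16 * a.strides[0], 16 * a.strides[1], *a.strides),
    writeable=False,
)
assert tiles.shape == (4, 4, 16, 16, 4)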
def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
# Detect hard edges
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
# Combine
npmask = npgradient + npedge
# Expand
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
new_mask = Image.fromarray(npmask)
if edge_blur > 0:
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
return ImageOps.invert(new_mask)
def seam_paint(self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,strength,
noise
) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_image = im.copy().convert('RGBA'),
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
strength = strength,
mask_blur_radius = 0,
seam_size = 0
)
result = make_image(noise)
return result
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, **kwargs):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
the time you call it. kwargs are 'init_latent' and 'strength'
"""
if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image
# Fill missing areas of original image
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
)
init_filled.paste(init_image, (0,0), init_image.split()[-1])
# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))
if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling,
mask_image.height // downsampling
),
resample=Image.Resampling.NEAREST
)
mask_image = self._image_to_tensor(mask_image,normalize=False)
self.mask_blur_radius = mask_blur_radius
# klms samplers not supported yet, so ignore previous sampler
if isinstance(sampler,KSampler):
print(
@ -45,7 +205,8 @@ class Inpaint(Img2Img):
) # move to latent space
t_enc = int(strength * steps)
# todo: support cross-attention control
uc, c, _ = conditioning
print(f">> target t_enc is {t_enc} steps")
@ -78,9 +239,77 @@ class Inpaint(Img2Img):
init_latent = self.init_latent
)
result = self.sample_to_image(samples)
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T)
return result
return make_image
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
# Get numpy version of result
np_image = np.asarray(image, dtype=np.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)
# Color correct
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
# Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0:
nm = np.asarray(pil_init_mask, dtype=np.uint8)
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
pmd = Image.fromarray(nmd, mode='L')
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
else:
blurred_init_mask = pil_init_mask
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
return matched_result
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
return corrected_result

View File

@ -0,0 +1,153 @@
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
import torch
import numpy as np
from einops import repeat
from PIL import Image, ImageOps
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
class Omnibus(Img2Img,Txt2Img):
def __init__(self, model, precision):
super().__init__(model, precision)
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
init_image = None,
mask_image = None,
strength = None,
step_callback=None,
threshold=0.0,
perlin=0.0,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
num_samples = 1
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, Image.Image):
if init_image.mode != 'RGB':
init_image = init_image.convert('RGB')
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
if init_image is not None and mask_image is not None: # inpainting
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a completely black mask (1s)
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
# and the masked image is just a copy of the original
masked_image = init_image
else: # txt2img
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
self.init_latent = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
with torch.no_grad():
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
batch = self.make_batch_sd(
init_image,
mask_image,
masked_image,
prompt=prompt,
device=model.device,
num_samples=num_samples,
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()
if ck != model.masked_image_key:
bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond={"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x_T,
conditioning = cond,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc_full,
eta = 1.0,
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
def make_batch_sd(
self,
image,
mask,
masked_image,
prompt,
device,
num_samples=1):
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [prompt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def get_noise(self, width:int, height:int):
if self.init_latent is not None:
height = self.init_latent.shape[2]
width = self.init_latent.shape[3]
return Txt2Img.get_noise(self,width,height)
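# A channel-count sketch for the 9-channel inpainting UNet this module drives
# (sizes assumed: a 512x512 image maps to a 64x64 latent). The UNet input is
# the 4-channel noisy latent concatenated with c_concat, which is the
# 1-channel downsampled mask plus the 4-channel encoded masked image.
import torch

noisy_latent = torch.randn(1, 4, 64, 64)
mask_small   = torch.ones(1, 1, 64, 64)    # mask interpolated to latent size
masked_enc   = torch.randn(1, 4, 64, 64)   # stand-in for encode_first_stage()

c_cat   = torch.cat([mask_small, masked_enc], dim=1)  # 5 conditioning channels
unet_in = torch.cat([noisy_latent, c_cat], dim=1)
assert unet_in.shape[1] == 9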

View File

@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img(Generator):
def __init__(self, model, precision):
@ -19,7 +21,7 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height'
"""
self.perlin = perlin
uc, c, extra_conditioning_info = conditioning
@torch.no_grad()
def make_image(x_T):
@ -43,6 +45,7 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,
@ -59,7 +62,7 @@ class Txt2Img(Generator):
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
@ -74,3 +77,4 @@ class Txt2Img(Generator):
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@ -5,9 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
import math
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from PIL import Image
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
@ -22,31 +24,29 @@ class Txt2Img2Img(Generator):
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c, extra_conditioning_info = conditioning
scale_dim = min(width, height)
scale = 512 / scale_dim
init_width = math.ceil(scale * width / 64) * 64
init_height = math.ceil(scale * height / 64) * 64
@torch.no_grad()
def make_image(x_T):
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
init_width // self.downsampling_factor,
]
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
#x = self.get_noise(init_width, init_height)
x = x_T
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
@ -60,17 +60,18 @@ class Txt2Img2Img(Generator):
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback,
extra_conditioning_info = extra_conditioning_info
)
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing
samples = torch.nn.functional.interpolate(
samples,
size=(height // self.downsampling_factor, width // self.downsampling_factor),
mode="bilinear"
)
@ -94,6 +95,8 @@ class Txt2Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
extra_conditioning_info=extra_conditioning_info,
all_timesteps_count=steps
)
if self.free_gpu_mem:
@ -101,8 +104,49 @@ class Txt2Img2Img(Generator):
return self.sample_to_image(samples)
return make_image
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
omnibus = Omnibus(self.model,self.precision)
result = omnibus.generate(
prompt,
sampler=sampler,
width=init_width,
height=init_height,
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
assert result is not None and len(result)>0,'** txt2img failed **'
image = result[0][0]
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
kwargs.pop('init_image',None)   # discard any caller-supplied init_image; the second pass passes its own
result = omnibus.generate(
prompt,
sampler=sampler,
init_image=interpolated_image,
width=width,
height=height,
seed=result[0][1],
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
return result[0][0]
if sampler.uses_inpainting_model():
return inpaint_make_image
else:
return make_image
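# A worked example of the sizing math above, with an assumed 960x640 request:
# the first pass renders near the model's 512x512 training resolution (the
# smaller dimension scaled to 512, both dimensions rounded up to multiples of
# 64), and the result is then interpolated up to the requested size.
import math

width, height = 960, 640
scale = 512 / min(width, height)                     # 0.8
init_width  = math.ceil(scale * width  / 64) * 64    # 768
init_height = math.ceil(scale * height / 64) * 64    # 512
assert (init_width, init_height) == (768, 512)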
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
@ -116,9 +160,9 @@ class Txt2Img2Img(Generator):
else:
scaled_width = width
scaled_height = height
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
@ -130,3 +174,4 @@ class Txt2Img2Img(Generator):
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)

View File

@ -13,17 +13,17 @@ import gc
import hashlib
import psutil
import transformers
import traceback
import os
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config
DEFAULT_MAX_MODELS=2
class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
@ -36,7 +36,7 @@ class ModelCache(object):
self.config = config
self.precision = precision
self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
@ -52,7 +52,9 @@ class ModelCache(object):
return None
if self.current_model != model_name:
if model_name not in self.models: # make room for a new one
self._make_cache_room()
self.offload_model(self.current_model)
if model_name in self.models:
requested_model = self.models[model_name]['model']
@ -61,8 +63,7 @@ class ModelCache(object):
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else: # we're about to load a new model, so potentially offload the least recently used one
try:
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {}
@ -72,8 +73,10 @@ class ModelCache(object):
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc())
print(f'** restoring {self.current_model}')
self.get_model(self.current_model)
return None
self.current_model = model_name
self._push_newest_model(model_name)
@ -84,6 +87,26 @@ class ModelCache(object):
'hash': hash
}
def default_model(self) -> str:
'''
Returns the name of the default model, or None
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
return model_name
return None
def set_default_model(self,model_name:str):
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
'''
assert model_name in self.config,f"unknown model '{model_name}'"
for model in self.config:
    self.config[model].pop('default',None)
self.config[model_name]['default'] = True
def list_models(self) -> dict:
'''
Return a dict of models in the format:
@ -121,12 +144,23 @@ class ModelCache(object):
else:
print(line)
def del_model(self, model_name:str) ->bool:
'''
Delete the named model.
'''
omega = self.config
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->bool:
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
'''
omega = self.config
# check that all the required fields are present
@ -139,17 +173,10 @@ class ModelCache(object):
config[field] = model_attributes[field]
omega[model_name] = config
if clobber:
self._invalidate_cached_model(model_name)
return True
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
@ -159,6 +186,7 @@ class ModelCache(object):
mconfig = self.config[model_name]
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get('vae',None)
width = mconfig.width
height = mconfig.height
@ -188,9 +216,20 @@ class ModelCache(object):
else:
print(' | Using more accurate float32 precision')
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae:
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')
model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device
model.eval()
for m in model.modules():
@ -209,16 +248,67 @@ class ModelCache(object):
)
return model, width, height, model_hash
def offload_model(self, model_name:str):
'''
Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed.
'''
if model_name not in self.models:
return
message = f'>> Offloading {model_name} to CPU'
print(message)
model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
def _make_cache_room(self):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
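# A self-contained sketch of the LRU discipline implemented by self.stack
# (a plain list used as a FIFO; model names here are hypothetical): accessing
# a model re-appends it, and when the cache is full the least recently used
# entry is popped from the front and purged.
MAX_LOADED = 2
stack, models = [], {}

def access(name):
    if name in stack:
        stack.remove(name)            # _push_newest_model(): move to the end
    elif len(models) >= MAX_LOADED:   # _make_cache_room(): purge the oldest
        models.pop(stack.pop(0))
    stack.append(name)
    models[name] = f'<weights for {name}>'

for m in ('stable-diffusion-1.4', 'stable-diffusion-1.5', 'inpainting-1.5'):
    access(m)
assert list(models) == ['stable-diffusion-1.5', 'inpainting-1.5']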
def print_vram_usage(self):
if self._has_cuda():
    print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
'''
yaml_str = OmegaConf.to_yaml(self.config)
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.rename(tmpfile,config_file_path)
def preamble(self):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
def _invalidate_cached_model(self,model_name:str):
self.offload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)
def _model_to_cpu(self,model):
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'
@ -243,8 +333,7 @@ class ModelCache(object):
to be the least recently accessed model.
'''
return self.stack.pop(0)
def _push_newest_model(self,model_name:str):
'''

View File

@ -38,7 +38,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)
if metadata:
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
return path
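# A round-trip sketch of the tEXt-chunk mechanism used above (file path and
# values are illustrative): metadata written via PngInfo can be read back
# from PIL's .text mapping, which is what metadata_from_png() relies on.
import json
import os
import tempfile
from PIL import Image, PngImagePlugin

info = PngImagePlugin.PngInfo()
info.add_text('sd-metadata', json.dumps({'seed': 42}))
path = os.path.join(tempfile.gettempdir(), 'demo.png')
Image.new('RGB', (8, 8)).save(path, 'PNG', pnginfo=info)
assert json.loads(Image.open(path).text['sd-metadata'])['seed'] == 42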

702
ldm/invoke/prompt_parser.py Normal file
View File

@ -0,0 +1,702 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used
to generate and control the conditioning tensors for weighted subprompts.
Useful class exports:
PromptParser - parses prompts
Useful function exports:
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individually as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
elif issubclass(type(fragment), BaseFragment):
self.children.append(fragment)
elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
self.children.append(Fragment(fragment[0], fragment[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.children == self.children
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
self.original = original
self.edited = edited
default_options = {
's_start': 0.0,
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
't_start': 0.0,
't_end': 1.0
}
merged_options = default_options
if options is not None:
shape_freedom = options.pop('shape_freedom', None)
if shape_freedom is not None:
# high shape freedom = SD can do what it wants with the shape of the object
# high shape freedom => s_end = 0
# low shape freedom => s_end = 1
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
# and there is very little perceptible difference as s_end increases above 0.5
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
# -> cube root and subtract from 1.0
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
#print('converted shape_freedom argument to', merged_options)
merged_options.update(options)
self.options = merged_options
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited \
and other.options == self.options
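# A worked check of the shape_freedom -> s_end mapping above: the cube root
# keeps mid-range slider values in the perceptually useful region, and
# shape_freedom=0.5 lands exactly on the default s_end.
s_end = 1.0 - 0.5 ** (1.0 / 3.0)
assert abs(s_end - 0.2062994740159002) < 1e-12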
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", parts)
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for p in prompts:
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
for f in p.children:
if isinstance(f, CrossAttentionControlSubstitute):
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x)
for x in prompts]
self.weights = weights
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse_conjunction(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.conjunction.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
def flatten(self, root: Conjunction) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
#print("flattening", root)
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
#print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight <1"
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
#print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
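# A worked example of how flatten_internal() accumulates nested Attention
# weights multiplicatively (using the parser's default '+' base of 1.1): in a
# prompt like "(a (b)+)+", fragment 'a' is scaled once and fragment 'b'
# twice, so 'b' ends up at 1.1 * 1.1 = 1.21.
attention_plus_base = 1.1
outer_weight = attention_plus_base                  # the enclosing "( ... )+"
inner_weight = attention_plus_base * outer_weight   # the nested "(b)+"
assert round(inner_weight, 2) == 1.21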
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quotes = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
attention = pp.Forward()
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
cross_attention_substitute = pp.Forward()
def make_text_fragment(x):
#print("### making fragment for", x)
if type(x[0]) is Fragment:
assert(False)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {type(x).__name__} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
escapes = []
for c in escaped_chars_to_ignore:
escapes.append(pp.Literal('\\'+c))
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(escapes + [pp.CharsNotIn(
string.whitespace + escaped_chars_to_ignore,
exact=1
)])
))
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
#print(f"parsing fragment string for {x}")
fragment_string = x[0]
#print(f"ppparsing fragment string \"{fragment_string}\"")
if len(fragment_string.strip()) == 0:
return Fragment('')
if in_quotes:
# escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"')
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
try:
result = pp.Group(pp.MatchFirst([
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
#print("parsed to", result)
return result
except pp.ParseException as e:
#print("parse_fragment_str couldn't parse prompt string:", e)
raise
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
def not_ends_with_swap(x):
#print("trying to match:", x)
return not x[0].endswith('.swap')
unquoted_word = (pp.Combine(pp.OneOrMore(
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
)))
# don't whitespace when the next word starts with +, eg "badly +formed"
+ (pp.White().suppress() |
# don't eat +/-
pp.NotAny(pp.Word('+') | pp.Word('-'))
)
)
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
parenthesized_fragment << (lparen +
pp.Or([
(parenthesized_fragment),
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
(pp.Combine(pp.OneOrMore(
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
pp.White()
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
pp.Empty()
]) + rparen)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention_with_parens = pp.Forward()
attention_without_parens = pp.Forward()
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_foot")\
.set_debug(False)
attention_with_parens <<= pp.Group(
lparen +
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
)
+ rparen + attention_with_parens_foot)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
attention_without_parens <<= pp.Group(pp.MatchFirst([
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
+ attention_without_parens_foot#.leave_whitespace()
]))
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
attention << pp.MatchFirst([attention_with_parens,
attention_without_parens
])
attention.set_name('attention')
def make_attention(x):
#print("entered make_attention with", x)
children = x[0][:-1]
weight_raw = x[0][-1]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
#print("making Attention from", children, "with weight", weight)
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
attention_with_parens.set_parse_action(make_attention)
attention_without_parens.set_parse_action(make_attention)
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
# cross attention control
debug_cross_attention_control = False
original_fragment = pp.MatchFirst([
quoted_fragment.set_debug(debug_cross_attention_control),
parenthesized_fragment.set_debug(debug_cross_attention_control),
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
empty_string.set_debug(debug_cross_attention_control),
])
# support keyword=number arguments
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
edited_fragment = pp.MatchFirst([
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
lparen +
(quoted_fragment | attention |
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
) +
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
rparen,
parenthesized_fragment
])
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
def make_cross_attention_substitute(x):
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
#if len(x>2):
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
#print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# root prompt definition
debug_root_prompt = False
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
attention.set_debug(debug_root_prompt),
quoted_fragment.set_debug(debug_root_prompt),
parenthesized_fragment.set_debug(debug_root_prompt),
unquoted_word.set_debug(debug_root_prompt),
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
) + pp.StringEnd()) \
.set_name('prompt') \
.set_parse_action(lambda x: Prompt(x)) \
.set_debug(debug_root_prompt)
#print("parsing test:", prompt.parse_string("spaced eyes--"))
#print("parsing test:", prompt.parse_string("eyes--"))
# weighted blend of prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
# int weights.
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
def make_prompt_from_quoted_string(x):
#print(' got quoted prompt', x)
x_unquoted = x[0][1:-1]
if len(x_unquoted.strip()) == 0:
# print(' b : just an empty string')
return Prompt([Fragment('')])
#print(f' b parsing \'{x_unquoted}\'')
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
debug_blend=False
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(debug_blend)
def make_blend(x):
prompts = x[0][0]
weights = x[0][1]
normalize = True
if weights[-1] == 'no_normalize':
normalize = False
weights = weights[:-1]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
blend.set_parse_action(make_blend)
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ pp.Literal(".and").suppress()
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
def make_conjunction(x):
parts_raw = x[0][0]
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
parts = [part for part in parts_raw]
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
conjunction.set_debug(False)
# top-level is a conjunction of one or more blends or prompts
return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
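# A worked example (hypothetical prompt) of the legacy-blend normalization
# above: "mountain:2 lake:1" splits into two subprompts whose weights are
# rescaled to sum to 1.
parsed = [('mountain', 2.0), ('lake', 1.0)]
weight_sum = sum(w for _, w in parsed)
normalized = [(text, w / weight_sum) for text, w in parsed]
assert normalized == [('mountain', 2.0 / 3.0), ('lake', 1.0 / 3.0)]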
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', 'x` ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -22,6 +22,7 @@ except (ImportError,ModuleNotFoundError):
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
TEXT_EXTENSIONS = ('.txt','.TXT')
CONFIG_EXTENSIONS = ('.yaml','.yml')
COMMANDS = (
'--steps','-s',
@ -54,12 +55,15 @@ COMMANDS = (
'--hires_fix',
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model','!del_model',
'!mask',
)
MODEL_COMMANDS = (
'!switch',
'!edit_model',
'!del_model',
)
WEIGHT_COMMANDS = (
'!import_model',
@ -67,16 +71,21 @@ WEIGHT_COMMANDS = (
IMG_PATH_COMMANDS = (
'--outdir[=\s]',
)
TEXT_PATH_COMMANDS=(
'!replay',
)
IMG_FILE_COMMANDS=(
'!fix',
'!fetch',
'!mask',
'--init_img[=\s]','-I',
'--init_mask[=\s]','-M',
'--init_color[=\s]',
'--embedding_path[=\s]',
)
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
class Completer(object):
def __init__(self, options, models=[]):
@ -119,6 +128,9 @@ class Completer(object):
elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [
@ -207,9 +219,24 @@ class Completer(object):
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line
readline.redisplay()
def add_model(self,model_name:str)->None:
'''
add a model name to the completion list
'''
self.models.append(model_name)
def del_model(self,model_name:str)->None:
'''
removes a model name from the completion list
'''
self.models.remove(model_name)
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:

View File

@ -41,10 +41,12 @@ class CodeFormerRestoration():
cf.eval()
image = image.convert('RGB')
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
face_helper.clean_all()
face_helper.read_image(np.array(image, dtype=np.uint8))
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
@ -71,7 +73,8 @@ class CodeFormerRestoration():
restored_img = face_helper.paste_faces_to_input_image()
res = Image.fromarray(restored_img)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed

View File

@ -55,13 +55,18 @@ class GFPGAN():
image = image.convert('RGB')
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
_, _, restored_img = self.gfpgan.enhance(
np.array(image, dtype=np.uint8),
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed

View File

@ -32,7 +32,7 @@ class Outcrop(object):
result= self.generate.prompt2image(
orig_opt.prompt,
# seed = orig_opt.seed, # uncomment to make it deterministic
seed = orig_opt.seed, # reuse the original seed so outcropping stays deterministic
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
@ -40,8 +40,15 @@ class Outcrop(object):
width = extended_image.width,
height = extended_image.height,
init_img = extended_image,
strength = opt.strength,
image_callback = wrapped_callback,
strength = 0.90,
image_callback = wrapped_callback if image_callback else None,
seam_size = opt.seam_size or 96,
seam_blur = opt.seam_blur or 16,
seam_strength = opt.seam_strength or 0.7,
seam_steps = 20,
tile_size = 32,
color_match = True,
force_outpaint = True, # this just stops the warning about erased regions
)
# swap sampler back
@ -89,23 +96,12 @@ class Outcrop(object):
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
# first paste places old image at top of extended image, stretch
# it, and applies a gaussian blur to it
# take the top half region, stretch and paste it
top_slice = image.crop(box=(0,0,image.width,pixels//2))
top_slice = top_slice.resize((image.width,pixels))
extended_img.paste(top_slice,box=(0,0))
extended_img.paste((0,0,0),[0,0,image.width,image.height+pixels])
extended_img.paste(image,box=(0,pixels))
# second paste creates a copy of the image displaced pixels downward;
# The overall effect is to create a blurred duplicate of the top portion of
# the image.
extended_img.paste(image,box=(0,pixels))
extended_img = extended_img.filter(filter=ImageFilter.GaussianBlur(radius=pixels//2))
extended_img.paste(image,box=(0,pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels*2))
alpha.paste(0,(0,0,extended_img.width,pixels))
extended_img.putalpha(alpha)
return extended_img

View File

@ -60,14 +60,18 @@ class ESRGAN():
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
# Real-ESRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
output, _ = upsampler.enhance(
np.array(image, dtype=np.uint8),
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler='realesrgan',
)
res = Image.fromarray(output)
# Flip the channels back to RGB
res = Image.fromarray(output[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed

30
ldm/invoke/seamless.py Normal file
View File

@ -0,0 +1,30 @@
import torch.nn as nn
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
"""
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
else:
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
if hasattr(m, 'asymmetric_padding_mode'):
del m.asymmetric_padding_mode
if hasattr(m, 'asymmetric_padding'):
del m.asymmetric_padding
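A minimal usage sketch, assuming `model` is a loaded LatentDiffusion instance (as held by Generate):
from ldm.invoke.seamless import configure_model_padding

configure_model_padding(model, seamless=True, seamless_axes=['x'])   # tile horizontally only
# ... generate as usual ...
configure_model_padding(model, seamless=False, seamless_axes=[])     # restore stock Conv2d behavior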

View File

@ -34,6 +34,7 @@ def build_opt(post_data, seed, gfpgan_model_exists):
setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0)
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
setattr(opt, 'progress_images', 'progress_images' in post_data)
setattr(opt, 'progress_latents', 'progress_latents' in post_data)
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'threshold', float(post_data['threshold']))
setattr(opt, 'perlin', float(post_data['perlin']))
@ -227,8 +228,13 @@ class DreamServer(BaseHTTPRequestHandler):
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
nonlocal step_index
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
image = self.model.sample_to_image(sample)
wants_progress_latents = opt.progress_latents
wants_progress_image = opt.progress_images and step % 5 == 0
if (wants_progress_image or wants_progress_latents) and step < opt.steps - 1:
image = self.model.sample_to_image(sample) if wants_progress_image \
else self.model.sample_to_lowres_estimated_image(sample)
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'

130
ldm/invoke/txt2mask.py Normal file
View File

@ -0,0 +1,130 @@
'''Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.
Here is typical usage:
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from PIL import Image
txt2mask = Txt2Mask(self.device)
segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
# this will return a grayscale Image of the segmented data
grayscale = segmented.to_grayscale()
# this will return a semi-transparent image in which the
# selected object(s) are opaque and the rest is at various
# levels of transparency
transparent = segmented.to_transparent()
# this will return a masked image suitable for use in inpainting:
mask = segmented.to_mask(threshold=0.5)
The threshold used in the call to to_mask() selects pixels for use in
the mask that exceed the indicated confidence threshold. Values range
from 0.0 to 1.0. The higher the threshold, the more confident the
algorithm must be that a pixel belongs to the object. In limited
testing, values around 0.5 work well.
'''
import torch
import numpy as np
from clipseg_models.clipseg import CLIPDensePredT
from einops import rearrange, repeat
from PIL import Image, ImageOps
from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
def __init__(self, image:Image, heatmap:torch.Tensor):
self.heatmap = heatmap
self.image = image
def to_grayscale(self,invert:bool=False)->Image:
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
def to_mask(self,threshold:float=0.5)->Image:
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
def to_transparent(self,invert:bool=False)->Image:
transparent_image = self.image.copy()
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite. Thus invert.
gs = self.to_grayscale(not invert)
transparent_image.putalpha(gs)
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap:Image)->Image:
size = self.image.width if (self.image.width > self.image.height) else self.image.height
resized_image = heatmap.resize(
(size,size),
resample=Image.Resampling.LANCZOS
)
return resized_image.crop((0,0,self.image.width,self.image.height))
class Txt2Mask(object):
'''
Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'.
'''
def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference')
self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
self.model.eval()
# initially we keep everything on the CPU to conserve GPU space
self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale:
'''
Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
'''
self._to_device(self.device)
prompts = [prompt] # right now we operate on just a single prompt at a time
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
])
if isinstance(image, str):
image = Image.open(image).convert('RGB')
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0)
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
heatmap = torch.sigmoid(preds[0][0]).cpu()
self._to_device('cpu')
return SegmentedGrayscale(image, heatmap)
def _to_device(self, device):
self.model.to(device)
def _scale_and_crop(self, image:Image)->Image:
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
if image.width > image.height: # width is constraint
scale = CLIPSEG_SIZE / image.width
else:
scale = CLIPSEG_SIZE / image.height
scaled_image.paste(
image.resize(
(int(scale * image.width),
int(scale * image.height)
),
resample=Image.Resampling.LANCZOS
),box=(0,0)
)
return scaled_image

View File

@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

View File

@ -0,0 +1,201 @@
from enum import Enum
import torch
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class Context:
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.arguments = arguments
self.step_count = step_count
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_cross_attention_control(cls, model,
cross_attention_control_args: Arguments
):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementation
:return: None
"""
# adapted from init_attention_edit
device = cross_attention_control_args.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
cls.inject_attention_function(model)
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
class CrossAttentionType(Enum):
SELF = 1
TOKENS = 2
@classmethod
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
opts = context.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through < opts['s_end']:
to_control.append(cls.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through < opts['t_end']:
to_control.append(cls.CrossAttentionType.TOKENS)
return to_control
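A hedged illustration of the gating above (option values invented):
# opts = {'s_start': 0.0, 's_end': 0.3, 't_start': 0.0, 't_end': 1.0}
# percent_through = 0.1  ->  [SELF, TOKENS]
# percent_through = 0.5  ->  [TOKENS]   (self-attention control has switched off)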
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model, clear_attn_slice=True):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
if clear_attn_slice:
m.last_attn_slice = None
@classmethod
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
m.use_last_attn_slice = True
@classmethod
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
if self.use_last_attn_slice:
if dim is None:
last_attn_slice = self.last_attn_slice
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
else:
last_attn_slice = self.last_attn_slice[offset]
if self.last_attn_slice_mask is None:
# just use everything
attn_slice = last_attn_slice
else:
last_attn_slice_mask = self.last_attn_slice_mask
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
this_attn_slice = attn_slice
this_attn_slice_mask = 1 - last_attn_slice_mask
attn_slice = this_attn_slice * this_attn_slice_mask + \
remapped_last_attn_slice * last_attn_slice_mask
if self.save_last_attn_slice:
if dim is None:
self.last_attn_slice = attn_slice
else:
if self.last_attn_slice is None:
self.last_attn_slice = { offset: attn_slice }
else:
self.last_attn_slice[offset] = attn_slice
return attn_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.last_attn_slice = None
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler)
@classmethod
def remove_attention_function(cls, unet):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)

View File

@ -1,10 +1,7 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
@ -12,6 +9,21 @@ class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps,device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# This is the central routine
@torch.no_grad()
def p_sample(
@ -29,6 +41,7 @@ class DDIMSampler(Sampler):
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -37,16 +50,17 @@ class DDIMSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(
x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index
)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(

View File

@ -19,6 +19,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
import urllib
from ldm.util import (
@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@ -820,21 +821,21 @@ class LatentDiffusion(DDPM):
)
return self.scale_factor * z
def get_learned_conditioning(self, c):
def get_learned_conditioning(self, c, **kwargs):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode
):
c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager
c, embedding_manager=self.embedding_manager,**kwargs
)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
c = self.cond_stage_model(c, **kwargs)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
return c
def meshgrid(self, h, w):
@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(
self,
@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
finetune_keys=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for inpainting"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = (
rearrange(batch[ck], "b h w c -> b c h w")
.to(memory_format=torch.contiguous_format)
.float()
)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds

View File

@ -1,16 +1,16 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler
from ldm.util import rand_perlin_2d
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
from torch import nn
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
# at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule
STEP_THRESHOLD = 30
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
@ -33,12 +33,21 @@ class CFGDenoiser(nn.Module):
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
@ -46,8 +55,7 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs):
@ -60,16 +68,9 @@ class KSampler(Sampler):
self.sigmas = None
self.ds = None
self.s_in = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
if self.karras_max is None:
self.karras_max = STEP_THRESHOLD
def make_schedule(
self,
@ -98,8 +99,13 @@ class KSampler(Sampler):
rho=7.,
device=self.device,
)
self.sigmas = self.model_sigmas
#self.sigmas = self.karras_sigmas
if ddim_num_steps >= self.karras_max:
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
self.sigmas = self.model_sigmas
else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
self.sigmas = self.karras_sigmas
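A hedged sketch of the crossover behaviour, assuming the default STEP_THRESHOLD of 30:
# ddim_num_steps = 20  ->  karras sigmas  (20 < 30)
# ddim_num_steps = 50  ->  model sigmas   (50 >= 30)
# the threshold can be overridden at construction time:
#     sampler = KSampler(model, 'lms', karras_max=42)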
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
@ -118,6 +124,7 @@ class KSampler(Sampler):
use_original_steps=False,
init_latent = None,
mask = None,
**kwargs
):
samples,_ = self.sample(
batch_size = 1,
@ -129,7 +136,8 @@ class KSampler(Sampler):
unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback,
x0 = init_latent,
mask = mask
mask = mask,
**kwargs
)
return samples
@ -163,6 +171,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -181,7 +190,6 @@ class KSampler(Sampler):
)
# sigmas are set up in make_schedule - we take the last steps items
total_steps = len(self.sigmas)
sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add
@ -195,19 +203,21 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
return (
sampling_result = (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
# a workaround is found.
@ -220,6 +230,7 @@ class KSampler(Sampler):
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
**kwargs,
):
if self.model_wrap is None:
@ -245,6 +256,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,
@ -269,7 +281,7 @@ class KSampler(Sampler):
else:
return x
def prepare_to_sample(self,t_enc):
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None
self.ds = None
@ -281,3 +293,6 @@ class KSampler(Sampler):
'''
return self.model.inner_model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key

View File

@ -5,6 +5,7 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
@ -13,6 +14,18 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# this is the essential routine
@torch.no_grad()
def p_sample(
@ -32,6 +45,7 @@ class PLMSSampler(Sampler):
unconditional_conditioning=None,
old_eps=[],
t_next=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -41,18 +55,15 @@ class PLMSSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(

View File

@ -2,13 +2,13 @@
ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
'''
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
@ -24,6 +24,8 @@ class Sampler(object):
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@ -158,6 +160,18 @@ class Sampler(object):
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None:
self.make_schedule(
@ -190,10 +204,11 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
**kwargs
)
return samples, intermediates
#torch.no_grad()
@torch.no_grad()
def do_sampling(
self,
cond,
@ -214,6 +229,7 @@ class Sampler(object):
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
**kwargs
):
b = shape[0]
time_range = (
@ -231,7 +247,7 @@ class Sampler(object):
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all
@ -274,6 +290,7 @@ class Sampler(object):
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
step_count=steps
)
img, pred_x0, e_t = outs
@ -305,8 +322,9 @@ class Sampler(object):
use_original_steps=False,
init_latent = None,
mask = None,
all_timesteps_count = None,
**kwargs
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
@ -321,7 +339,7 @@ class Sampler(object):
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -353,6 +371,7 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next = ts_next,
step_count=len(self.ddim_timesteps)
)
x_dec, pred_x0, e_t = outs
@ -411,3 +430,21 @@ class Sampler(object):
return self.model.inner_model.q_sample(x0,ts)
'''
return self.model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.model.conditioning_key
def uses_inpainting_model(self)->bool:
return self.conditioning_key() in ('hybrid','concat')
def adjust_settings(self,**kwargs):
'''
This is a catch-all method for adjusting any instance variables
after the sampler is instantiated. No type-checking performed
here, so use with care!
'''
for k in kwargs.keys():
try:
setattr(self,k,kwargs[k])
except AttributeError:
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
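A hedged usage sketch of the catch-all setter:
sampler.adjust_settings(karras_max=42)   # e.g. move the Karras/model schedule crossover on a KSampler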

View File

@ -0,0 +1,225 @@
from math import ceil
from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
class InvokeAIDiffuserComponent:
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
'''
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
self.cross_attention_control_args = cross_attention_control_args
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply); it will be called repeatedly and most likely should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: Optional[int]=None
):
"""
:param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. If None, will be estimated by comparing sigma against self.model.sigmas.
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
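The arithmetic above is standard classifier-free guidance; a hedged numeric check:
# with unconditioned_next_x = 0.0, conditioned_next_x = 1.0 and scale = 7.5:
#     combined_next_x = 0.0 + (1.0 - 0.0) * 7.5 = 7.5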
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = dict()
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]])
for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model)
return unconditioned_next_x, conditioned_next_x
except RuntimeError:
# make sure we clean out the attention slices we're storing on the model
# TODO don't store things on the model
CrossAttentionControl.clear_requests(self.model)
raise
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
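A hedged numeric check of the sigma fallback above:
# with 50 sigmas and the current sigma landing at index 24:
#     percent_through = 1.0 - (24 + 1) / 50 = 0.5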
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# ugly, but functional: run the unconditioned and weighted conditionings through the model in pairs
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c,weight in weighted_cond_list]
weights = [1] + [weight for c,weight in weighted_cond_list]
chunk_count = ceil(len(conditionings)/2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index*2
chunk_size = min(2, len(conditionings)-offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset:offset+2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale

View File

@ -1,5 +1,7 @@
from inspect import isfunction
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn, einsum
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
@ -170,46 +173,71 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(self, q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
self.attention_slice_wrangler = None
def einsum_op_slice_0(self, q, k, v, slice_size):
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention
suggested_attention_slice is a softmax(dim=-1) over attention_scores
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is >= 0, offset and slice_size specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
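A hedged sketch of a wrangler conforming to the documented signature (pass-through, logging only); `some_cross_attention_module` is a stand-in for any CrossAttention instance:
def debug_wrangler(module, attention_scores, suggested_attention_slice, dim, offset, slice_size):
    # log the slice geometry, then return the default softmaxed slice unchanged
    print(f'wrangler: scores {tuple(attention_scores.shape)}, dim={dim}, offset={offset}, slice_size={slice_size}')
    return suggested_attention_slice

some_cross_attention_module.set_attention_slice_wrangler(debug_wrangler)   # pass None to restore the default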
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_1(self, q, k, v, slice_size):
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_1(q, k, v, slice_size)
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_0(q, k, v, 1)
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
@ -221,7 +249,7 @@ class CrossAttention(nn.Module):
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(self, q, k, v):
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
@ -244,8 +272,13 @@ class CrossAttention(nn.Module):
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
r = self.get_attention_mem_efficient(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module):

View File

@ -64,7 +64,8 @@ def make_ddim_timesteps(
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
# ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
if c < 1:
c = 1
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
elif ddim_discr_method == 'quad':
ddim_timesteps = (

View File

@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
from functools import partial
@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text, **kwargs):
return self(text, **kwargs)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
return_tokens_key = "return_tokens"
def forward(self, text: list, **kwargs):
'''
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
weights shall be applied.
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
'''
if self.fragment_weights_key not in kwargs:
# fallback to base class implementation
return super().forward(text, **kwargs)
fragment_weights = kwargs[self.fragment_weights_key]
# self.transformer doesn't like receiving "fragment_weights" as an argument
kwargs.pop(self.fragment_weights_key)
should_return_tokens = False
if self.return_tokens_key in kwargs:
should_return_tokens = kwargs.get(self.return_tokens_key, False)
# self.transformer doesn't like having extra kwargs
kwargs.pop(self.return_tokens_key)
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
# "red" is to tell SD that it should almost completely *ignore* redness).
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
# handle weights >=1
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
# this is our starting point
embeddings = base_embedding.unsqueeze(0)
per_embedding_weights = [1.0]
# now handle weights <1
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
# removed.
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
# fragment_weight = 1: we are at base_z => lerp weight 0
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
# so let's use tan(), because:
# tan is 0.0 at 0,
# 1.0 at PI/4, and
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
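# A quick numeric check of the mapping (editor's illustration, values approximate):
#   fragment_weight 1.0 -> tan(0.00*pi/2) = 0.0   (stay at base_z)
#   fragment_weight 0.5 -> tan(0.50*pi/2) = 1.0   (equal pull: half-way lerp)
#   fragment_weight 0.1 -> tan(0.90*pi/2) ~ 6.31  (almost entirely "without")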
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
# todo handle negative weight?
per_embedding_weights.append(embedding_lerp_weight)
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
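# Usage sketch for forward() (editor's illustration; the prompt and weights are hypothetical):
# given an instance `embedder` of this class,
#   z = embedder([['a mountain', 'a man']], fragment_weights=[[1.0, 0.5]])
# returns a (1, 77, 768) tensor in which "a man" contributes at half strength.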
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
tokens = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
if include_start_and_end_markers:
return tokens
else:
return [x[1:-1] for x in tokens]
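# For example (editor's illustration): get_tokens(['a cat']) returns the token ids for
# [<bos>, 'a', 'cat', <eos>]; with include_start_and_end_markers=False the
# <bos>/<eos> markers are stripped and only the ids for ['a', 'cat'] remain.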
@classmethod
def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize: bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
# summing over the stacking dim collapses (1, N, 77, 768) to (1, 77, 768);
# the caller squeezes dim 0 to obtain lerped embeddings of shape (77, 768)
return torch.sum(embeddings * reshaped_weights, dim=1)
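# Worked example (editor's illustration): for "mountain:1 man:0.5", forward() stacks
# base_embedding ("mountain man") with one "without" embedding ("mountain") and passes
# per_embedding_weights [1.0, tan(pi/4)] = [1.0, 1.0]; normalize=True rescales these to
# [0.5, 0.5], so the weighted sum lands exactly half-way between the two embeddings.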
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> tuple[torch.Tensor, torch.Tensor]:
'''
:param fragments: A list of prompt-fragment strings to tokenize into a single sequence.
:param weights: Per-fragment weights (effectively per-token multipliers on the CFG scale). These need not be normalized, and are not normalized here.
:return: A tuple (tokens, per_token_weights) of tensors, each of shape (77,) (77 = self.max_length).
'''
# an empty prompt is meaningful: represent it as a single empty fragment with weight 1
if len(fragments) == 0 and len(weights) == 0:
fragments = ['']
weights = [1]
item_encodings = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=True,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
all_tokens = []
per_token_weights = []
#print("all fragments:", fragments, weights)
for index, fragment in enumerate(item_encodings):
weight = weights[index]
#print("processing fragment", fragment, weight)
fragment_tokens = item_encodings[index]
#print("fragment", fragment, "processed to", fragment_tokens)
# trim bos and eos markers before appending
all_tokens.extend(fragment_tokens[1:-1])
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
if (len(all_tokens) + 2) > self.max_length:
excess_token_count = (len(all_tokens) + 2) - self.max_length
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_tokens = all_tokens[:self.max_length - 2]
per_token_weights = per_token_weights[:self.max_length - 2]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
# (77 = self.max_length)
pad_length = self.max_length - 1 - len(all_tokens)
all_tokens.insert(0, self.tokenizer.bos_token_id)
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
per_token_weights.insert(0, 1)
per_token_weights.extend([1] * pad_length)
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
return all_tokens_tensor, per_token_weights_tensor
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
'''
Build a tensor representing the passed-in tokens, each of which has a weight.
:param tokens: A tensor of shape (77) containing token ids (integers)
:param per_token_weights: A tensor of shape (77) containing weights (floats)
:param weight_delta_from_empty: If True (the default), weights scale each token's distance from an "empty" prompt's feature vector; if False, they scale the whole feature vector directly (with a mean correction).
:param kwargs: passed on to self.transformer()
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
if weight_delta_from_empty:
empty_tokens = self.tokenizer([''] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)['input_ids'].to(self.device)
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
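# sanity check (editor's note): a per-token weight of 1 leaves that token's feature
# vector unchanged, while a weight of 0 collapses it to empty_z, i.e. the token
# contributes nothing beyond an "empty" prompt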
return weighted_z
else:
original_mean = z.mean()
z *= batch_weights_expanded
after_weighting_mean = z.mean()
# correct the mean; it is not certain this is right, but it's what the automatic1111 fork of SD does
mean_correction_factor = original_mean/after_weighting_mean
z *= mean_correction_factor
return z
class FrozenCLIPTextEmbedder(nn.Module):
"""