Merge branch 'development' into fix_generate.py

Lincoln Stein
2022-11-05 12:47:35 -07:00
committed by GitHub
587 changed files with 54471 additions and 6579 deletions


@@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
- for file_path in os.listdir(self.data_root)
+ for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)


@@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
- for file_path in os.listdir(self.data_root)
+ for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)
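
For reference, a minimal sketch of what the added if file_path != ".DS_Store" guard does (file names below are hypothetical, not from the repository): macOS Finder drops a .DS_Store metadata file into directories, and without the filter it would be collected as if it were a training image.

import os

data_root = "/tmp/training_images"                # hypothetical path
listing = ["cat1.png", "cat2.png", ".DS_Store"]   # what os.listdir(data_root) might return on macOS
image_paths = [
    os.path.join(data_root, f)
    for f in listing
    if f != ".DS_Store"                           # skip Finder metadata
]
print(image_paths)  # ['/tmp/training_images/cat1.png', '/tmp/training_images/cat2.png']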


@@ -1,96 +0,0 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.
Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
import re
import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
uc = model.get_learned_conditioning([''])
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
prompt, skip_normalize
)
if len(weighted_subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens)
c = torch.add(
c,
model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens)
c = model.get_learned_conditioning([prompt])
return (uc, c)
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / len(parsed_prompts)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
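
As a worked illustration of the parsing described in the module docstring above, here is a small self-contained sketch that mirrors the documented behaviour of split_weighted_subprompts (the prompt text is hypothetical, and this is a compact restatement of the regex, not the removed file itself):

import re

prompt_parser = re.compile(r"(?P<prompt>(?:\\:|[^:])+)(?::+(?P<weight>-?\d+(?:\.\d+)?)?\s*|$)")

def split_weighted_subprompts_sketch(text, skip_normalize=False):
    parsed = [(m.group("prompt").replace("\\:", ":"), float(m.group("weight") or 1))
              for m in prompt_parser.finditer(text)]
    if skip_normalize:
        return parsed
    total = sum(w for _, w in parsed)
    if total == 0:
        return [(p, 1 / len(parsed)) for p, _ in parsed]
    return [(p, w / total) for p, w in parsed]

print(split_weighted_subprompts_sketch("mountain:2 river"))
# -> [('mountain', 0.666...), ('river', 0.333...)]: the missing weight defaults to 1.0
#    and both weights are divided by their sum (2 + 1 = 3)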


@@ -1,20 +0,0 @@
import torch
from torch import autocast
from contextlib import contextmanager, nullcontext
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
def choose_autocast_device(device):
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
# autocast only supports cuda or cpu
if device_type in ('cuda','cpu'):
return device_type,autocast
else:
return 'cpu',nullcontext


@@ -1,4 +0,0 @@
'''
Initialization file for the ldm.dream.generator package
'''
from .base import Generator


@@ -1,72 +0,0 @@
'''
ldm.dream.generator.img2img descends from ldm.dream.generator
'''
import torch
import numpy as np
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator):
def __init__(self,model):
super().__init__(model)
self.init_latent = None # by get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
# PLMS sampler not supported yet, so ignore previous sampler
if not isinstance(sampler,DDIMSampler):
print(
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
)
sampler = DDIMSampler(self.model, device=self.model.device)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
device_type,scope = choose_autocast_device(self.model.device)
with scope(device_type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
t_enc = int(strength * steps)
uc, c = conditioning
@torch.no_grad()
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
return self.sample_to_image(samples)
return make_image
def get_noise(self,width,height):
device = self.model.device
init_latent = self.init_latent
assert init_latent is not None,'call to get_noise() when init_latent not set'
if device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(device)
else:
return torch.randn_like(init_latent, device=device)
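
A small worked example (hypothetical values) of the strength-to-noising relationship this generator relies on:

steps = 50
strength = 0.75
t_enc = int(strength * steps)
print(t_enc)  # 37: the init latent is stochastically noised to step 37 of 50 and then
              # denoised back, so a lower strength preserves more of the initial image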


@@ -1,77 +0,0 @@
'''
ldm.dream.generator.inpaint descends from ldm.dream.generator
'''
import torch
import numpy as np
from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
class Inpaint(Img2Img):
def __init__(self,model):
self.init_latent = None
super().__init__(model)
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
step_callback=None,**kwargs):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
the time you call it. kwargs are 'init_latent' and 'strength'
"""
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
# PLMS sampler not supported yet, so ignore previous sampler
if not isinstance(sampler,DDIMSampler):
print(
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
)
sampler = DDIMSampler(self.model, device=self.model.device)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
device_type,scope = choose_autocast_device(self.model.device)
with scope(device_type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
t_enc = int(strength * steps)
uc, c = conditioning
print(f">> target t_enc is {t_enc} steps")
@torch.no_grad()
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
mask = mask_image,
init_latent = self.init_latent
)
return self.sample_to_image(samples)
return make_image


@@ -1,127 +0,0 @@
"""
Readline helper functions for dream.py (linux and mac only).
"""
import os
import re
import atexit
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except:
readline_available = False
class Completer:
def __init__(self, options):
self.options = sorted(options)
return
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img','-M','--init_mask')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
return self._path_completions(text, state, ())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self, text, state, extensions):
# get the path so far
# TODO: replace this mess with a regular expression match
if text.startswith('-I'):
path = text.replace('-I', '', 1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=', '', 1).lstrip()
elif text.startswith('--init_mask='):
path = text.replace('--init_mask=', '', 1).lstrip()
elif text.startswith('-M'):
path = text.replace('-M', '', 1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path) == 0:
matches.append(text + './')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n) > 1:
continue
full_path = os.path.join(dir, n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text), n))
try:
response = matches[state]
except IndexError:
response = None
return response
if readline_available:
readline.set_completer(
Completer(
[
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--grid','-g',
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--gfpgan_strength','-G',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','t',
]
).complete
)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file, histfile)
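
The same standard-library pattern in a minimal, self-contained form (the word list is hypothetical; this is independent of the Completer class above and works on Linux/macOS where readline is available):

import readline

WORDS = ['--steps', '--seed', '--sampler']

def complete(text, state):
    matches = [w for w in WORDS if w.startswith(text)]
    return matches[state] if state < len(matches) else None

readline.set_completer(complete)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
# input() now tab-completes over WORDS at the prompt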

File diff suppressed because it is too large.


@@ -1,167 +0,0 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
from scripts.dream import create_argv_parser
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)
def run_gfpgan(image, strength, seed, upsampler_scale=4):
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
gfpgan = None
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
if not gfpgan_model_exists:
raise Exception('GFPGAN model not found at path ' + model_path)
sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer
bg_upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
gfpgan = GFPGANer(
model_path=model_path,
upscale=upsampler_scale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if gfpgan is None:
print(
f'>> WARNING: GFPGAN not initialized.'
)
print(
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
)
return image
image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gfpgan = None
return res
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path:
return None
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
else:
bg_upsampler = None
return bg_upsampler
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
except Exception:
import traceback
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output, img_mode = upsampler.enhance(
np.array(image, dtype=np.uint8),
outscale=upsampler_scale,
alpha_upsampler=opt.gfpgan_bg_upsampler,
)
res = Image.fromarray(output)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res
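
A small sketch (solid-colour stand-in images, not from the repository) of the Image.blend strength mixing used after both face restoration and upscaling above:

from PIL import Image

original = Image.new('RGB', (64, 64), (0, 0, 0))
restored = Image.new('RGB', (64, 64), (200, 200, 200))
strength = 0.5
blended = Image.blend(original, restored, strength)  # 1.0 keeps only the restored image, 0.0 only the original
print(blended.getpixel((0, 0)))                      # (100, 100, 100)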

ldm/invoke/args.py (new file, 1105 lines)

File diff suppressed because it is too large.

ldm/invoke/conditioning.py (new file, 195 lines)

@@ -0,0 +1,195 @@
'''
This module handles the generation of the conditioning tensors.
Useful function exports:
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
'''
import re
from difflib import SequenceMatcher
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
# we don't support conjunctions for now
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for i,flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_opcodes = []
edit_options = []
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
)
)
def build_token_edit_opcodes(original_tokens, edited_tokens):
original_tokens = original_tokens.cpu().numpy()[0]
edited_tokens = edited_tokens.cpu().numpy()[0]
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if log_tokens:
text = " ".join(fragments)
log_tokenization(text, model, display_label=log_display_label)
return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid'
'''
assert isinstance(uncond, dict)
assert isinstance(cond, dict)
cond_flattened = dict()
for k in cond:
if isinstance(cond[k], list):
cond_flattened[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
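
A short sketch (hypothetical prompt string) of the bracket-based negative-prompt extraction performed at the top of get_uc_and_c_and_ec:

import re

prompt = "a castle on a hill [blurry] at sunset [low quality]"
negatives = re.findall(r'\[(.*?)\]', prompt)
cleaned = re.sub(' +', ' ', re.sub(r'\[(.*?)\]', ' ', prompt))
print(' '.join(negatives))  # 'blurry low quality' -> parsed as the unconditioned (negative) prompt
print(cleaned)              # 'a castle on a hill at sunset ' -> bracketed text stripped, extra spaces collapsed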

ldm/invoke/devices.py (new file, 27 lines)

@@ -0,0 +1,27 @@
import torch
from torch import autocast
from contextlib import nullcontext
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
def choose_precision(device) -> str:
'''Returns an appropriate precision for the given torch device'''
if device.type == 'cuda':
device_name = torch.cuda.get_device_name(device)
if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
return 'float16'
return 'float32'
def choose_autocast(precision):
'''Returns an autocast context or nullcontext for the given precision string'''
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext
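
A sketch of how these three helpers are meant to combine on a recent PyTorch install (the device handling here is an assumption for illustration, not part of the diff):

import torch
from contextlib import nullcontext
from torch import autocast

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # cf. choose_torch_device()
precision = 'float16' if device.type == 'cuda' else 'float32'          # cf. choose_precision(), ignoring the GTX 16xx special case
scope = autocast if precision == 'float16' else nullcontext            # cf. choose_autocast()

with scope(device.type):
    a = torch.randn(4, 4, device=device)
    b = torch.randn(4, 4, device=device)
    print((a @ b).dtype)  # torch.float16 under CUDA autocast, torch.float32 otherwise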


@@ -0,0 +1,4 @@
'''
Initialization file for the ldm.invoke.generator package
'''
from .base import Generator


@@ -1,26 +1,36 @@
'''
Base class for ldm.dream.generator.*
Base class for ldm.invoke.generator.*
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
import os
import traceback
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'assets/caution.png'
class Generator():
def __init__(self,model):
self.model = model
self.seed = None
self.latent_channels = model.channels
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.variation_amount = 0
self.with_variations = []
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@@ -35,23 +45,31 @@ class Generator():
self.variation_amount = variation_amount
self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None,
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
device_type,scope = choose_autocast_device(self.model.device)
make_image = self.get_make_image(
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
sampler = sampler,
init_image = init_image,
width = width,
height = height,
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
**kwargs
)
results = []
seed = seed if seed else self.new_seed()
seed = seed if seed is not None else self.new_seed()
first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(device_type), self.model.ema_scope():
# There used to be an additional self.model.ema_scope() here, but it breaks
# the inpaint-1.5 model. Not sure what it did.... ?
with scope(self.model.device.type):
for n in trange(iterations, desc='Generating'):
x_T = None
if self.variation_amount > 0:
@@ -63,21 +81,30 @@ class Generator():
x_T = initial_noise
else:
seed_everything(seed)
if self.model.device.type == 'mps':
try:
x_T = self.get_noise(width,height)
except:
print('** An error occurred while getting initial noise **')
print(traceback.format_exc())
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed)
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results
def sample_to_image(self,samples):
def sample_to_image(self,samples)->Image.Image:
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
Given samples returned from a sampler, converts
it into a PIL Image
"""
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -89,6 +116,29 @@ class Generator():
)
return Image.fromarray(x_sample.astype(np.uint8))
# write an approximate RGB image from latent samples for a single step to PNG
def sample_to_lowres_estimated_image(self,samples):
# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor([
# R G B
[ 0.3444, 0.1385, 0.0670], # L1
[ 0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445] # L4
], dtype=samples.dtype, device=samples.device)
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
def generate_initial_noise(self, seed, width, height):
initial_noise = None
if self.variation_amount > 0 or len(self.with_variations) > 0:
@@ -115,6 +165,10 @@ class Generator():
"""
raise NotImplementedError("get_noise() must be implemented in a descendent class")
def get_perlin_noise(self,width,height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
def new_seed(self):
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
@@ -156,3 +210,48 @@ class Generator():
return v2
def safety_check(self,image:Image.Image):
'''
If the CompViz safety checker flags an NSFW image, we
blur it out.
'''
import diffusers
checker = self.safety_checker['checker']
extractor = self.safety_checker['extractor']
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
if has_nsfw_concept[0]:
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
return self.blur(image)
else:
return image
def blur(self,input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = Image.open(CAUTION_IMG)
caution = caution.resize((caution.width // 2, caution.height //2))
blurry.paste(caution,(0,0),caution)
except FileNotFoundError:
pass
return blurry
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):
image = self.sample_to_image(sample)
dirname = os.path.dirname(filepath) or '.'
if not os.path.exists(dirname):
print(f'** creating directory {dirname}')
os.makedirs(dirname, exist_ok=True)
image.save(filepath,'PNG')
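
A compact sketch of the latent-to-approximate-RGB projection introduced in sample_to_lowres_estimated_image above, run on a random stand-in latent:

import torch

v1_5_latent_rgb_factors = torch.tensor([
    [ 0.3444,  0.1385,  0.0670],
    [ 0.1247,  0.4027,  0.1494],
    [-0.3192,  0.2513,  0.2103],
    [-0.1307, -0.1874, -0.7445],
])
latents = torch.randn(1, 4, 64, 64)                            # hypothetical 4-channel latent batch
rgb = latents[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors    # (64, 64, 3)
ubyte = (((rgb + 1) / 2).clamp(0, 1).mul(255)).byte()          # rescale -1..1 to 0..255
print(ubyte.shape)  # torch.Size([64, 64, 3]) -> ready for Image.fromarray(ubyte.numpy())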


@@ -0,0 +1,501 @@
'''
ldm.invoke.generator.embiggen descends from ldm.invoke.generator
and generates with ldm.invoke.generator.img2img
'''
import torch
import numpy as np
from tqdm import trange
from PIL import Image
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.devices import choose_autocast
from ldm.models.diffusion.ddim import DDIMSampler
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None
# Replace generate because Embiggen doesn't need/use most of what it does normally
def generate(self,prompt,iterations=1,seed=None,
image_callback=None, step_callback=None,
**kwargs):
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
prompt,
step_callback = step_callback,
**kwargs
)
results = []
seed = seed if seed else self.new_seed()
# Noise will be generated by the Img2Img generator when called
with scope(self.model.device.type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
# make_image will call Img2Img which will do the equivalent of get_noise itself
image = make_image()
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed)
seed = self.new_seed()
return results
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_img,
strength,
width,
height,
embiggen,
embiggen_tiles,
step_callback=None,
**kwargs
):
"""
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
print(
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
# and then sort them, because... people.
if embiggen_tiles:
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
embiggen_tiles.sort()
if strength >= 0.5:
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
# Prep img2img generator, since we wrap over it
gen_img2img = Img2Img(self.model,self.precision)
# Open original init image (not a tensor) to manipulate
initsuperimage = Image.open(init_img)
with Image.open(init_img) as img:
initsuperimage = img.convert('RGB')
# Size of the target super init image in pixels
initsuperwidth, initsuperheight = initsuperimage.size
# Increase by scaling factor if not already resized, using ESRGAN as able
if embiggen[0] != 1.0:
initsuperwidth = round(initsuperwidth*embiggen[0])
initsuperheight = round(initsuperheight*embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ldm.invoke.restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
print(
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
if embiggen[0] > 2:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
4, # upscale scale
)
else:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
2, # upscale scale
)
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
# but from personal experience it doesn't greatly improve anything after 4x
# Resize to target scaling factor resolution
initsuperimage = initsuperimage.resize(
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
# Use width and height as tile widths and height
# Determine buffer size in pixels
if embiggen[2] < 1:
if embiggen[2] < 0:
embiggen[2] = 0
overlap_size_x = round(embiggen[2] * width)
overlap_size_y = round(embiggen[2] * height)
else:
overlap_size_x = round(embiggen[2])
overlap_size_y = round(embiggen[2])
# With overall image width and height known, determine how many tiles we need
def ceildiv(a, b):
return -1 * (-a // b)
# X and Y need to be determined independently (we may have savings on one based on the buffer pixel count)
# (initsuperwidth - width) is the area remaining to the right that we need to lay tiles over to fill
# (width - overlap_size_x) is how much new we can fill with a single tile
emb_tiles_x = 1
emb_tiles_y = 1
if (initsuperwidth - width) > 0:
emb_tiles_x = ceildiv(initsuperwidth - width,
width - overlap_size_x) + 1
if (initsuperheight - height) > 0:
emb_tiles_y = ceildiv(initsuperheight - height,
height - overlap_size_y) + 1
# Sanity
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
# Prep alpha layers --------------
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
# agradientL is Left-side transparent
agradientL = Image.linear_gradient('L').rotate(
90).resize((overlap_size_x, height))
# agradientT is Top-side transparent
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
agradientC = Image.new('L', (256, 256))
for y in range(256):
for x in range(256):
# Find distance to lower right corner (numpy takes arrays)
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
# Clamp values to max 255
if distanceToLR > 255:
distanceToLR = 255
#Place the pixel as invert of distance
agradientC.putpixel((x, y), round(255 - distanceToLR))
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
agradientAsymC = Image.new('L', (256, 256))
for y in range(256):
for x in range(256):
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
#Clamp values
value = max(0, value)
value = min(255, value)
agradientAsymC.putpixel((x, y), value)
# Create alpha layers default fully white
alphaLayerL = Image.new("L", (width, height), 255)
alphaLayerT = Image.new("L", (width, height), 255)
alphaLayerLTC = Image.new("L", (width, height), 255)
# Paste gradients into alpha layers
alphaLayerL.paste(agradientL, (0, 0))
alphaLayerT.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientL, (0, 0))
alphaLayerLTC.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
alphaLayerTaC = alphaLayerT.copy()
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerLTaC = alphaLayerLTC.copy()
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
if embiggen_tiles:
# Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
# Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerRBC.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientC.rotate(90).resize(
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# All but X layers
alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerABT.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerABL.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
# All-around layer
alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerAA.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# Clean up temporary gradients
del agradientL
del agradientT
del agradientC
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
else:
print(
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
emb_tile_store = []
# Although we could use the same seed for every tile for determinism, at higher strengths this may
# produce duplicated structures for each tile and make the tiling effect more obvious
# instead track and iterate a local seed we pass to Img2Img
seed = self.seed
seedintlimit = np.iinfo(np.uint32).max - 1 # only retrieve this one from numpy
for tile in range(emb_tiles_x * emb_tiles_y):
# Don't iterate on first tile
if tile != 0:
if seed < seedintlimit:
seed += 1
else:
seed = 0
# Determine if this is a re-run and replace
if embiggen_tiles and not tile in embiggen_tiles:
continue
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
# Determine bounds to cut up the init image
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
right = left + width
bottom = top + height
# Cropped image of above dimension (does not modify the original)
newinitimage = initsuperimage.crop((left, top, right, bottom))
# DEBUG:
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
print(
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
else:
print(
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
# create a torch tensor from an Image
newinitimage = np.array(
newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0
newinitimage = newinitimage.to(self.model.device)
tile_results = gen_img2img.generate(
prompt,
iterations = 1,
seed = seed,
sampler = DDIMSampler(self.model, device=self.model.device),
steps = steps,
cfg_scale = cfg_scale,
conditioning = conditioning,
ddim_eta = ddim_eta,
image_callback = None, # called only after the final image is generated
step_callback = step_callback, # called after each intermediate image is generated
width = width,
height = height,
init_image = newinitimage, # notice that init_image is different from init_img
mask_image = None,
strength = strength,
)
emb_tile_store.append(tile_results[0][0])
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
del newinitimage
# Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
outputsuperimage = Image.new(
"RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles:
outputsuperimage.alpha_composite(
initsuperimage.convert('RGBA'), (0, 0))
for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles:
if tile in embiggen_tiles:
intileimage = emb_tile_store.pop(0)
else:
continue
else:
intileimage = emb_tile_store[tile]
intileimage = intileimage.convert('RGBA')
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
left = 0
top = 0
else:
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i *
(width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
# Handle gradients for various conditions
# Handle emb_rerun case
if embiggen_tiles:
# top of image
if emb_row_i == 0:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerB)
# Otherwise do nothing on this tile
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerR)
else:
intileimage.putalpha(alphaLayerRBC)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerLR)
else:
intileimage.putalpha(alphaLayerABT)
# bottom of image
elif emb_row_i == emb_tiles_y - 1:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerRTC)
elif emb_column_i == emb_tiles_x - 1:
# No tiles to look ahead to
intileimage.putalpha(alphaLayerLTC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABB)
# vertical middle of image
else:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerTB)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerRTC)
else:
intileimage.putalpha(alphaLayerABL)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABR)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerABB)
else:
intileimage.putalpha(alphaLayerAA)
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
else:
if emb_row_i == 0 and emb_column_i >= 1:
intileimage.putalpha(alphaLayerL)
elif emb_row_i >= 1 and emb_column_i == 0:
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerTaC)
else:
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerLTaC)
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
# after internal loops and patching up return Embiggen image
return outputsuperimage
# end of function declaration
return make_image
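
A worked example (hypothetical sizes) of the tile-count arithmetic used above to decide how many Embiggen tiles are needed:

def ceildiv(a, b):
    return -1 * (-a // b)

initsuperwidth, initsuperheight = 1536, 1024   # scaled-up init image
width, height = 512, 512                       # per-tile generation size
overlap_size_x = round(0.25 * width)           # 128 px of overlap
overlap_size_y = round(0.25 * height)          # 128 px of overlap

emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1     # ceildiv(1024, 384) + 1 = 4
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1  # ceildiv(512, 384) + 1 = 3
print(emb_tiles_x, emb_tiles_y)  # 4 3 -> 12 tiles in total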


@@ -0,0 +1,90 @@
'''
ldm.invoke.generator.img2img descends from ldm.invoke.generator
'''
import torch
import numpy as np
import PIL
from torch import Tensor
from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image.convert('RGB'))
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
t_enc = int(strength * steps)
uc, c, extra_conditioning_info = conditioning
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info,
all_timesteps_count = steps
)
return self.sample_to_image(samples)
return make_image
def get_noise(self,width,height):
device = self.model.device
init_latent = self.init_latent
assert init_latent is not None,'call to get_noise() when init_latent not set'
if device.type == 'mps':
x = torch.randn_like(init_latent, device='cpu').to(device)
else:
x = torch.randn_like(init_latent, device=device)
if self.perlin > 0.0:
shape = init_latent.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0
return image.to(self.model.device)
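
A tiny sketch of what _image_to_tensor does to a PIL image (a hypothetical 2x2 solid-red image stands in for the init image):

import numpy as np
import torch
from PIL import Image

img = Image.new('RGB', (2, 2), (255, 0, 0))
arr = np.array(img).astype(np.float32) / 255.0          # (2, 2, 3) scaled to 0..1
t = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))   # (1, 3, 2, 2), channels first
t = 2.0 * t - 1.0                                       # normalize to -1..1 as the model expects
print(t.shape, t.min().item(), t.max().item())          # torch.Size([1, 3, 2, 2]) -1.0 1.0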


@@ -0,0 +1,315 @@
'''
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
import math
import torch
import torchvision.transforms as T
import numpy as np
import cv2 as cv
import PIL
from PIL import Image, ImageFilter, ImageOps
from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.invoke.generator.base import downsampling
class Inpaint(Img2Img):
def __init__(self, model, precision):
self.init_latent = None
self.pil_image = None
self.pil_mask = None
self.mask_blur_radius = 0
super().__init__(model, precision)
# Outpaint support code
def get_tile_images(self, image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False
)
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
# Only fill if there's an alpha layer
if im.mode != 'RGBA':
return im
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a,*tile_size).copy()
# Get the mask as tiles
tiles_mask = tiles[:,:,:,:,3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = (tiles_mask > 0)
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed = seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1,2)
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
si = Image.fromarray(st, mode='RGBA')
return si
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
# Detect hard edges
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
# Combine
npmask = npgradient + npedge
# Expand
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
new_mask = Image.fromarray(npmask)
if edge_blur > 0:
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
return ImageOps.invert(new_mask)
def seam_paint(self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,strength,
noise
) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_image = im.copy().convert('RGBA'),
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
strength = strength,
mask_blur_radius = 0,
seam_size = 0
)
result = make_image(noise)
return result
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, **kwargs):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
the time you call it. kwargs are 'init_latent' and 'strength'
"""
if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image
# Fill missing areas of original image
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
)
init_filled.paste(init_image, (0,0), init_image.split()[-1])
# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))
if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling,
mask_image.height // downsampling
),
resample=Image.Resampling.NEAREST
)
mask_image = self._image_to_tensor(mask_image,normalize=False)
self.mask_blur_radius = mask_blur_radius
# klms samplers not supported yet, so ignore previous sampler
if isinstance(sampler,KSampler):
print(
f">> Using recommended DDIM sampler for inpainting."
)
sampler = DDIMSampler(self.model, device=self.model.device)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
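# broadcast the single-channel mask across all 4 latent channels and add a batch dimension so it can gate the latents directly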
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
t_enc = int(strength * steps)
# todo: support cross-attention control
uc, c, _ = conditioning
print(f">> target t_enc is {t_enc} steps")
@torch.no_grad()
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# to replace masked area with latent noise, weighted by inpaint_replace strength
if inpaint_replace > 0.0:
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
z_enc = z_enc * mask_image + masked_region
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
mask = mask_image,
init_latent = self.init_latent
)
result = self.sample_to_image(samples)
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T)
return result
return make_image
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
# Get numpy version of result
np_image = np.asarray(image, dtype=np.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)
# Color correct
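# per-channel statistics transfer: normalize the generated pixels by their own mean/std, then rescale to the reference image's statistics: out = (gen - gen_mean) / gen_std * init_std + init_mean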
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
# Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0:
nm = np.asarray(pil_init_mask, dtype=np.uint8)
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
pmd = Image.fromarray(nmd, mode='L')
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
else:
blurred_init_mask = pil_init_mask
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
return matched_result
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
return corrected_result

View File

@ -0,0 +1,153 @@
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
import torch
import numpy as np
from einops import repeat
from PIL import Image, ImageOps
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
class Omnibus(Img2Img,Txt2Img):
def __init__(self, model, precision):
super().__init__(model, precision)
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
init_image = None,
mask_image = None,
strength = None,
step_callback=None,
threshold=0.0,
perlin=0.0,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
num_samples = 1
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, Image.Image):
if init_image.mode != 'RGB':
init_image = init_image.convert('RGB')
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
if init_image is not None and mask_image is not None: # inpainting
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a mask of all ones
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
# and the masked image is just a copy of the original
masked_image = init_image
else: # txt2img
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
self.init_latent = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
with torch.no_grad():
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
batch = self.make_batch_sd(
init_image,
mask_image,
masked_image,
prompt=prompt,
device=model.device,
num_samples=num_samples,
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()
if ck != model.masked_image_key:
bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
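# c_cat now carries the 5 extra conditioning channels (1-channel downsampled mask + 4-channel encoded masked image) that the 9-channel inpainting UNet concatenates with the 4 noised latent channels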
# cond
cond={"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x_T,
conditioning = cond,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc_full,
eta = 1.0,
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
def make_batch_sd(
self,
image,
mask,
masked_image,
prompt,
device,
num_samples=1):
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [prompt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def get_noise(self, width:int, height:int):
if self.init_latent is not None:
height = self.init_latent.shape[2]
width = self.init_latent.shape[3]
return Txt2Img.get_noise(self,width,height)

View File

@ -1,24 +1,27 @@
'''
ldm.dream.generator.txt2img inherits from ldm.dream.generator
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
'''
import torch
import numpy as np
from ldm.dream.generator.base import Generator
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img(Generator):
def __init__(self,model):
super().__init__(model)
def __init__(self, model, precision):
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,**kwargs):
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c = conditioning
self.perlin = perlin
uc, c, extra_conditioning_info = conditioning
@torch.no_grad()
def make_image(x_T):
@ -27,6 +30,12 @@ class Txt2Img(Generator):
height // self.downsampling_factor,
width // self.downsampling_factor,
]
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
@ -36,9 +45,15 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta,
img_callback = step_callback
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
@ -47,15 +62,19 @@ class Txt2Img(Generator):
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
if device.type == 'mps':
return torch.randn([1,
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
return torch.randn([1,
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@ -0,0 +1,186 @@
'''
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
'''
import torch
import numpy as np
import math
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.image_util import InitImageResizer
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,strength,step_callback=None,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c, extra_conditioning_info = conditioning
scale_dim = min(width, height)
scale = 512 / scale_dim
init_width = math.ceil(scale * width / 64) * 64
init_height = math.ceil(scale * height / 64) * 64
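# e.g. a 768x1024 request runs the first pass at 512x704 (scale = 512/768, each dimension rounded up to a multiple of 64)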
@torch.no_grad()
def make_image(x_T):
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
init_width // self.downsampling_factor,
]
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x_T,
conditioning = c,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback,
extra_conditioning_info = extra_conditioning_info
)
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing
image = self.sample_to_image(samples)
image = InitImageResizer(image).resize(width, height)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.0 * image - 1.0
image = image.to(self.model.device)
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
samples = self.model.get_first_stage_encoding(
self.model.encode_first_stage(image)
) # move back to latent space
t_enc = int(strength * steps)
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
ddim_sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
z_enc = ddim_sampler.stochastic_encode(
samples,
torch.tensor([t_enc]).to(self.model.device),
noise=self.get_noise(width,height,False)
)
# decode it
samples = ddim_sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
extra_conditioning_info=extra_conditioning_info,
all_timesteps_count=steps
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a small (~512 pixel) PIL image, upscale it, and run the inpainting
# model over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
omnibus = Omnibus(self.model,self.precision)
result = omnibus.generate(
prompt,
sampler=sampler,
width=init_width,
height=init_height,
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
assert result is not None and len(result)>0,'** txt2img failed **'
image = result[0][0]
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
kwargs.pop('init_image', None)  # drop any caller-supplied init_image so it doesn't conflict with the one we pass explicitly
result = omnibus.generate(
prompt,
sampler=sampler,
init_image=interpolated_image,
width=width,
height=height,
seed=result[0][1],
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
return result[0][0]
if sampler.uses_inpainting_model():
return inpaint_make_image
else:
return make_image
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
# print(f"Get noise: {width}x{height}")
if scale:
trained_square = 512 * 512
actual_square = width * height
scale = math.sqrt(trained_square / actual_square)
scaled_width = math.ceil(scale * width / 64) * 64
scaled_height = math.ceil(scale * height / 64) * 64
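# e.g. a 1024x1024 request gives scale = sqrt(512*512 / (1024*1024)) = 0.5, so the initial noise is generated at 512x512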
else:
scaled_width = width
scaled_height = height
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device='cpu').to(device)
else:
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)

66 ldm/invoke/log.py Normal file
View File

@ -0,0 +1,66 @@
"""
Functions for better-formatted logging
write_log -- logs the name of the output image, the prompt, and the prompt args to the terminal and to one or more files
1 write_log_message -- Writes a message to the console
2 write_log_files -- Writes a message to files
2.1 write_log_default -- File in plain text
2.2 write_log_txt -- File in txt format
2.3 write_log_markdown -- File in markdown format
"""
import os
def write_log(results, log_path, file_types, output_cntr):
"""
logs the name of the output image, prompt, and prompt args to the terminal and files
"""
output_cntr = write_log_message(results, output_cntr)
write_log_files(results, log_path, file_types)
return output_cntr
def write_log_message(results, output_cntr):
"""logs to the terminal"""
if len(results) == 0:
return output_cntr
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
if len(log_lines)>1:
subcntr = 1
for l in log_lines:
print(f"[{output_cntr}.{subcntr}] {l}", end="")
subcntr += 1
else:
print(f"[{output_cntr}] {log_lines[0]}", end="")
return output_cntr+1
def write_log_files(results, log_path, file_types):
for file_type in file_types:
if file_type == "txt":
write_log_txt(log_path, results)
elif file_type == "md" or file_type == "markdown":
write_log_markdown(log_path, results)
else:
print(f"'{file_type}' format is not supported, so write in plain text")
write_log_default(log_path, results, file_type)
def write_log_default(log_path, results, file_type):
plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
file.writelines(plain_txt_lines)
def write_log_txt(log_path, results):
txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + ".txt", "a", encoding="utf-8") as file:
file.writelines(txt_lines)
def write_log_markdown(log_path, results):
md_lines = []
for path, prompt in results:
file_name = os.path.basename(path)
md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
with open(log_path + ".md", "a", encoding="utf-8") as file:
file.writelines(md_lines)
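# Minimal usage sketch (paths and prompt are illustrative, not part of this module):
#   results = [('outputs/000001.png', 'an astronaut riding a horse -s50 -W512 -H512')]
#   counter = write_log(results, 'outputs/dream_log', ['txt', 'md'], output_cntr=1)
#   # prints '[1] outputs/000001.png: ...' and appends entries to dream_log.txt and dream_log.md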

370 ldm/invoke/model_cache.py Normal file
View File

@ -0,0 +1,370 @@
'''
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. When the number of
cached models exceeds a preset maximum, the least recently used model
will be cleared and loaded from disk when next needed.
'''
import torch
import os
import io
import time
import gc
import hashlib
import psutil
import transformers
import traceback
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config
DEFAULT_MAX_MODELS=2
class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
max_loaded_models argument specifies how many models may be
kept in the RAM cache at once; when the limit is reached, the
least recently used model is purged to make room (default 2).
'''
# prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error()
self.config = config
self.precision = precision
self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
def get_model(self, model_name:str):
'''
Given a model name identified in models.yaml, return
the model object. If it is cached in RAM it will be moved into GPU VRAM;
if it is on disk, it will be loaded from there.
'''
if model_name not in self.config:
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return None
if self.current_model != model_name:
if model_name not in self.models: # make room for a new one
self._make_cache_room()
self.offload_model(self.current_model)
if model_name in self.models:
requested_model = self.models[model_name]['model']
print(f'>> Retrieving model {model_name} from system RAM cache')
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else: # we're about to load a new model, so potentially offload the least recently used one
try:
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {}
self.models[model_name]['model'] = requested_model
self.models[model_name]['width'] = width
self.models[model_name]['height'] = height
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc())
print(f'** restoring {self.current_model}')
self.get_model(self.current_model)
return None
self.current_model = model_name
self._push_newest_model(model_name)
return {
'model':requested_model,
'width':width,
'height':height,
'hash': hash
}
def default_model(self) -> str:
'''
Returns the name of the default model, or None
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
return model_name
return None
def set_default_model(self,model_name:str):
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
'''
assert model_name in self.config, f"unknown model '{model_name}'"
for model in self.config:
self.config[model].pop('default', None)
self.config[model_name]['default'] = True
def list_models(self) -> dict:
'''
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
},
model_name2: { etc }
'''
result = {}
for name in self.config:
try:
description = self.config[name].description
except ConfigAttributeError:
description = '<no description>'
if self.current_model == name:
status = 'active'
elif name in self.models:
status = 'cached'
else:
status = 'not loaded'
result[name]={}
result[name]['status']=status
result[name]['description']=description
return result
def print_models(self):
'''
Print a table of models, their descriptions, and load status
'''
models = self.list_models()
for name in models:
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
if models[name]['status'] == 'active':
print(f'\033[1m{line}\033[0m')
else:
print(line)
def del_model(self, model_name:str) ->bool:
'''
Delete the named model.
'''
omega = self.config
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> bool:
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if a required attribute is missing, or if the name already
exists and clobber is not True (pass clobber=True to overwrite an existing
definition). On a successful update, the config is changed in memory and the
method returns True.
'''
omega = self.config
# check that all the required fields are present
for field in ('description','weights','height','width','config'):
assert field in model_attributes, f'required field {field} is missing'
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
config = omega[model_name] if model_name in omega else {}
for field in model_attributes:
config[field] = model_attributes[field]
omega[model_name] = config
if clobber:
self._invalidate_cached_model(model_name)
return True
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
return None
mconfig = self.config[model_name]
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get('vae',None)
width = mconfig.width
height = mconfig.height
print(f'>> Loading {model_name} from {weights}')
# for usage statistics
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
tic = time.time()
# this does the work
c = OmegaConf.load(config)
with open(weights,'rb') as f:
weight_bytes = f.read()
model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
if self.precision == 'float16':
print(' | Using faster float16 precision')
model.to(torch.float16)
else:
print(' | Using more accurate float32 precision')
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae:
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')
model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device
model.eval()
for m in model.modules():
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
# usage statistics
toc = time.time()
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
if self._has_cuda():
print(
'>> Max VRAM used to load the model:',
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
'\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)
return model, width, height, model_hash
def offload_model(self, model_name:str):
'''
Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed.
'''
if model_name not in self.models:
return
message = f'>> Offloading {model_name} to CPU'
print(message)
model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
def _make_cache_room(self):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
def print_vram_usage(self):
if self._has_cuda():
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
'''
yaml_str = OmegaConf.to_yaml(self.config)
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.rename(tmpfile,config_file_path)
def preamble(self):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to the InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
def _invalidate_cached_model(self,model_name:str):
self.offload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)
def _model_to_cpu(self,model):
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'
model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu')
model.model.to('cpu')
return model.to('cpu')
else:
return model
def _model_from_cpu(self,model):
if self.device != 'cpu':
model.to(self.device)
model.first_stage_model.to(self.device)
model.cond_stage_model.to(self.device)
model.cond_stage_model.device = self.device
return model
def _pop_oldest_model(self):
'''
Remove the first element of the FIFO, which ought
to be the least recently accessed model. Do not
pop the last one, because it is in active use!
'''
return self.stack.pop(0)
def _push_newest_model(self,model_name:str):
'''
Maintain a simple FIFO. First element is always the
least recent, and last element is always the most recent.
'''
try:
self.stack.remove(model_name)
except ValueError:
pass
self.stack.append(model_name)
def _has_cuda(self):
return self.device.type == 'cuda'
def _cached_sha256(self,path,data):
dirname = os.path.dirname(path)
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f:
hash = f.read()
return hash
print(f'>> Calculating sha256 hash of weights file')
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f:
f.write(hash)
return hash
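# Typical usage sketch (model name and config path are illustrative):
#   config = OmegaConf.load('configs/models.yaml')
#   cache  = ModelCache(config, device_type='cuda', precision='float16', max_loaded_models=2)
#   entry  = cache.get_model('stable-diffusion-1.5')   # loads from disk or restores from the RAM cache
#   model, width, height = entry['model'], entry['width'], entry['height']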

View File

@ -3,12 +3,13 @@ Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
into the PNG.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches.
Exports function retrieve_metadata(path)
"""
import os
import re
from PIL import PngImagePlugin
import json
from PIL import PngImagePlugin, Image
# -------------------image generation utils-----
@ -32,13 +33,39 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, prompt, name):
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', prompt)
image.save(path, 'PNG', pnginfo=info)
info.add_text('Dream', dream_prompt)
if metadata:
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
return path
def retrieve_metadata(self,img_basename):
'''
Given a PNG filename stored in outdir, returns the "sd-metadata"
metadata stored there, as a dict
'''
path = os.path.join(self.outdir,img_basename)
all_metadata = retrieve_metadata(path)
return all_metadata['sd-metadata']
def retrieve_metadata(img_path):
'''
Given a path to a PNG image, returns the "sd-metadata"
metadata stored there, as a dict
'''
im = Image.open(img_path)
md = im.text.get('sd-metadata', '{}')
dream_prompt = im.text.get('Dream', '')
return {'sd-metadata': json.loads(md), 'Dream': dream_prompt}
def write_metadata(img_path:str, meta:dict):
im = Image.open(img_path)
info = PngImagePlugin.PngInfo()
info.add_text('sd-metadata', json.dumps(meta))
im.save(img_path,'PNG',pnginfo=info)
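# Example round trip (file name is illustrative):
#   meta = retrieve_metadata('outputs/000001.png')   # -> {'sd-metadata': {...}, 'Dream': '...'}
#   write_metadata('outputs/000001.png', meta['sd-metadata'])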
class PromptFormatter:
def __init__(self, t2i, opt):

667 ldm/invoke/prompt_parser.py Normal file
View File

@ -0,0 +1,667 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used to generate and control the conditioning tensors, including management of weighted subprompts.
Useful class exports:
PromptParser - parses prompts
Useful function exports:
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individually as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
elif issubclass(type(fragment), BaseFragment):
self.children.append(fragment)
elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
self.children.append(Fragment(fragment[0], fragment[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
if type(weight) is not float:
raise PromptParser.ParsingException(
f"Attention weight must be float (got {type(weight).__name__} {weight})")
self.weight = weight
if type(children) is not list:
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
assert(type(children) is list)
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:{self.children} * {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.children == self.children
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: list, edited: list, options: dict=None):
self.original = original if len(original)>0 else [Fragment('')]
self.edited = edited if len(edited)>0 else [Fragment('')]
default_options = {
's_start': 0.0,
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
't_start': 0.0,
't_end': 1.0
}
merged_options = default_options
if options is not None:
shape_freedom = options.pop('shape_freedom', None)
if shape_freedom is not None:
# high shape freedom = SD can do what it wants with the shape of the object
# high shape freedom => s_end = 0
# low shape freedom => s_end = 1
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
# and there is very little perceptible difference as s_end increases above 0.5
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
# -> cube root and subtract from 1.0
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
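# e.g. shape_freedom=0.5 gives s_end = 1 - 0.5**(1/3) ≈ 0.206, matching the default above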
#print('converted shape_freedom argument to', merged_options)
merged_options.update(options)
self.options = merged_options
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited \
and other.options == self.options
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for p in prompts:
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
for f in p.children:
if isinstance(f, CrossAttentionControlSubstitute):
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x)
for x in prompts]
self.weights = weights
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
class UnrecognizedOperatorException(ParsingException):
def __init__(self, operator:str):
super().__init__("Unrecognized operator: " + operator)
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse_conjunction(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.conjunction.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
verbose and print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults or type(node) is list:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight < 1
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
verbose and print("flattening", root)
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
verbose and print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
def make_operator_object(x):
#print('making operator for', x)
target = x[0]
operator = x[1]
arguments = x[2]
if operator == '.attend':
weight_raw = arguments[0]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
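# e.g. '++' with attention_plus_base=1.1 gives 1.1**2 = 1.21; '--' with attention_minus_base=0.9 gives 0.9**2 = 0.81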
return Attention(weight=weight, children=[x for x in x[0]])
elif operator == '.swap':
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
elif operator == '.blend':
prompts = [Prompt(p) for p in x[0]]
weights_raw = x[2]
normalize_weights = True
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
normalize_weights = False
weights_raw = weights_raw[:-1]
weights = [float(w[0]) for w in weights_raw]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
elif operator == '.and' or operator == '.add':
prompts = [Prompt(p) for p in x[0]]
weights = [float(w[0]) for w in x[2]]
return Conjunction(prompts=prompts, weights=weights)
raise PromptParser.UnrecognizedOperatorException(operator)
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
#print(f"parsing fragment string for {x}")
fragment_string = x[0]
if len(fragment_string.strip()) == 0:
return Fragment('')
if in_quotes:
# escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"')
try:
result = (expression + pp.StringEnd()).parse_string(fragment_string)
#print("parsed to", result)
return result
except pp.ParseException as e:
#print("parse_fragment_str couldn't parse prompt string:", e)
raise
# meaningful symbols
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quote = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
dot = pp.Literal(".").suppress()
equals = pp.Literal("=").suppress()
escaped_lparen = pp.Literal('\\(')
escaped_rparen = pp.Literal('\\)')
escaped_quote = pp.Literal('\\"')
escaped_comma = pp.Literal('\\,')
escaped_dot = pp.Literal('\\.')
escaped_plus = pp.Literal('\\+')
escaped_minus = pp.Literal('\\-')
escaped_equals = pp.Literal('\\=')
syntactic_symbols = {
'(': escaped_lparen,
')': escaped_rparen,
'"': escaped_quote,
',': escaped_comma,
'.': escaped_dot,
'+': escaped_plus,
'-': escaped_minus,
'=': escaped_equals,
}
syntactic_chars = "".join(syntactic_symbols.keys())
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
# for options
keyword = pp.Word(pp.alphanums + '_')
# a word that absolutely does not contain any meaningful syntax
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
pp.Or(syntactic_symbols.values()),
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
# build character-by-character
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
])))
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
non_syntax_word.set_name('non_syntax_word')
non_syntax_word.set_debug(False)
# a word that can contain any character at all - greedily consumes syntax, so use with care
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
free_word.set_name('free_word')
free_word.set_debug(False)
# ok here we go. forward declare some things..
attention = pp.Forward()
cross_attention_substitute = pp.Forward()
parenthesized_fragment = pp.Forward()
quoted_fragment = pp.Forward()
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
fragment_part_expressions = [
attention,
cross_attention_substitute,
parenthesized_fragment,
quoted_fragment,
non_syntax_word
]
# a fragment that is permitted to contain commas
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions + [
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
]
))
# a fragment that is not permitted to contain commas
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions
))
# a fragment in double quotes (may be nested)
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
# a fragment inside parentheses (may be nested)
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
parenthesized_fragment.set_name('parenthesized_fragment')
parenthesized_fragment.set_debug(False)
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
option = pp.Group(pp.MatchFirst([
keyword + equals + (number | keyword), # option=value
number.copy().set_parse_action(pp.token_map(str)), # weight
keyword # flag
]))
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
options.set_name('options')
options.set_debug(False)
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
# a fragment whose weight has been increased or decreased by a given amount
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
attention_explicit = (
pp.Group(potential_operator_target)
+ pp.Literal('.attend')
+ lparen
+ pp.Group(attention_weight_operator)
+ rparen
)
attention_explicit.set_parse_action(make_operator_object)
attention_implicit = (
pp.Group(potential_operator_target)
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
+ pp.Group(attention_weight_operator)
)
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
attention << (attention_explicit | attention_implicit)
attention.set_name('attention')
attention.set_debug(False)
# cross-attention control by swapping one fragment for another
cross_attention_substitute << (
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
+ lparen
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
+ rparen
)
cross_attention_substitute.set_name('cross_attention_substitute')
cross_attention_substitute.set_debug(False)
cross_attention_substitute.set_parse_action(make_operator_object)
# an entire self-contained prompt, which can be used in a Blend or Conjunction
prompt = pp.ZeroOrMore(pp.MatchFirst([
cross_attention_substitute,
attention,
quoted_fragment,
parenthesized_fragment,
free_word,
pp.White().suppress()
]))
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
# a blend/lerp between the feature vectors for two or more prompts
blend = (
lparen
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
+ rparen
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('bl-options').set_debug(False)
+ rparen
)
blend.set_name('blend')
blend.set_debug(False)
blend.set_parse_action(make_operator_object)
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
explicit_conjunction = (
lparen
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
+ rparen
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('cj-options').set_debug(False)
+ rparen
)
explicit_conjunction.set_name('explicit_conjunction')
explicit_conjunction.set_debug(False)
explicit_conjunction.set_parse_action(make_operator_object)
# by default a prompt consists of a Conjunction with a single term
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = (explicit_conjunction | implicit_conjunction)
return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
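# Illustrative parses (results abbreviated):
#   PromptParser().parse_conjunction('a (white cat)1.2 on a car')
#     -> Conjunction of one FlattenedPrompt: [Fragment('a'@1.0), Fragment('white cat'@1.2), Fragment('on a car'@1.0)]
#   PromptParser().parse_conjunction('a cat.swap(dog) sitting on a car')
#     -> the FlattenedPrompt contains CrossAttentionControlSubstitute([Fragment('cat')] -> [Fragment('dog')])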

380 ldm/invoke/readline.py Normal file
View File

@ -0,0 +1,380 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from ldm.invoke.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
"""
import os
import re
import atexit
from ldm.invoke.args import Args
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except (ImportError,ModuleNotFoundError):
readline_available = False
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
WEIGHT_EXTENSIONS = ('.ckpt','.vae')
TEXT_EXTENSIONS = ('.txt','.TXT')
CONFIG_EXTENSIONS = ('.yaml','.yml')
COMMANDS = (
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--threshold',
'--perlin',
'--grid','-g',
'--individual','-i',
'--save_intermediates',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--facetool','-ft',
'--facetool_strength','-G',
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model','!del_model',
'!mask',
)
MODEL_COMMANDS = (
'!switch',
'!edit_model',
'!del_model',
)
WEIGHT_COMMANDS = (
'!import_model',
)
IMG_PATH_COMMANDS = (
'--outdir[=\s]',
)
TEXT_PATH_COMMANDS=(
'!replay',
)
IMG_FILE_COMMANDS=(
'!fix',
'!fetch',
'!mask',
'--init_img[=\s]','-I',
'--init_mask[=\s]','-M',
'--init_color[=\s]',
'--embedding_path[=\s]',
)
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
class Completer(object):
def __init__(self, options, models=[]):
self.options = sorted(options)
self.models = sorted(models)
self.seeds = set()
self.matches = list()
self.default_dir = None
self.linebuffer = None
self.auto_history_active = True
self.extensions = None
return
def complete(self, text, state):
'''
Completes invoke command line.
BUG: it doesn't correctly complete files that have spaces in the name.
'''
buffer = readline.get_line_buffer()
if state == 0:
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
# looking for a seed
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
# looking for a model
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def complete_extensions(self, extensions:list):
'''
If called with a list of extensions, will force completer
to do file path completions.
'''
self.extensions=extensions
def add_history(self,line):
'''
Pass thru to readline
'''
if not self.auto_history_active:
readline.add_history(line)
def clear_history(self):
'''
Pass clear_history() thru to readline
'''
readline.clear_history()
def search_history(self,match:str):
'''
Like show_history() but only shows items that
contain the match string.
'''
self.show_history(match)
def remove_history_item(self,pos):
readline.remove_history_item(pos)
def add_seed(self, seed):
'''
Add a seed to the autocomplete list for display when -S is autocompleted.
'''
if seed is not None:
self.seeds.add(str(seed))
def set_default_dir(self, path):
self.default_dir=path
def get_line(self,index):
try:
line = self.get_history_item(index)
except IndexError:
return None
return line
def get_current_history_length(self):
return readline.get_current_history_length()
def get_history_item(self,index):
return readline.get_history_item(index)
def show_history(self,match=None):
'''
Print the session history using the pydoc pager
'''
import pydoc
lines = list()
h_len = self.get_current_history_length()
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
line = self.get_history_item(i+1)
if match and match not in line:
continue
lines.append(f'[{i+1}] {line}')
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line
readline.redisplay()
def add_model(self,model_name:str)->None:
'''
add a model name to the completion list
'''
self.models.append(model_name)
def del_model(self,model_name:str)->None:
'''
removes a model name from the completion list
'''
self.models.remove(model_name)
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.seeds:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def _model_completions(self, text, state):
m = re.search('(!switch\s+)(\w*)',text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.models:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def _path_completions(self, text, state, extensions, shortcut_ok=True):
# separate the switch from the partial path
match = re.search('^(-\w|--\w+=?)(.*)',text)
if match is None:
switch = None
partial_path = text
else:
switch,partial_path = match.groups()
partial_path = partial_path.lstrip()
matches = list()
path = os.path.expanduser(partial_path)
if os.path.isdir(path):
dir = path
elif os.path.dirname(path) != '':
dir = os.path.dirname(path)
else:
dir = ''
path= os.path.join(dir,path)
dir_list = os.listdir(dir or '.')
if shortcut_ok and os.path.exists(self.default_dir) and dir=='':
dir_list += os.listdir(self.default_dir)
for node in dir_list:
if node.startswith('.') and len(node) > 1:
continue
full_path = os.path.join(dir, node)
if not (node.endswith(extensions) or os.path.isdir(full_path)):
continue
if not full_path.startswith(path):
continue
if switch is None:
match_path = os.path.join(dir,node)
matches.append(match_path+'/' if os.path.isdir(full_path) else match_path)
elif os.path.isdir(full_path):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node) + '/'
)
elif node.endswith(extensions):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node)
)
return matches
class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_history(self,line):
self.history.append(line)
def clear_history(self):
self.history = list()
def get_current_history_length(self):
return len(self.history)
def get_history_item(self,index):
return self.history[index-1]
def remove_history_item(self,index):
return self.history.pop(index-1)
def set_line(self,line):
print(f'# {line}')
def get_completer(opt:Args, models=[])->Completer:
if readline_available:
completer = Completer(COMMANDS,models)
readline.set_completer(
completer.complete
)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(False)
completer.auto_history_active = False
        except AttributeError:
completer.auto_history_active = True
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
readline.parse_and_bind('set print-completions-horizontally off')
readline.parse_and_bind('set page-completions on')
readline.parse_and_bind('set skip-completed-text on')
readline.parse_and_bind('set show-all-if-ambiguous on')
histfile = os.path.join(os.path.expanduser(opt.outdir), '.invoke_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file, histfile)
else:
completer = DummyCompleter(COMMANDS)
return completer
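# Usage sketch: wiring the completer into a simple read loop. The Args()
# defaults and the model name are placeholders; invoke.py performs the real
# setup.
if __name__ == '__main__':
    opt = Args()
    opt.parse_args()
    completer = get_completer(opt, models=['stable-diffusion-1.5'])
    completer.add_seed(42)
    while True:
        try:
            command = input('invoke> ')
        except (EOFError, KeyboardInterrupt):
            break
        completer.add_history(command)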

View File

@ -0,0 +1,4 @@
'''
Initialization file for the ldm.invoke.restoration package
'''
from .base import Restoration

View File

@ -0,0 +1,38 @@
class Restoration():
def __init__(self) -> None:
pass
def load_face_restore_models(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.4.pth'):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_dir, gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
print('>> GFPGAN Initialized')
else:
print('>> GFPGAN Disabled')
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
print('>> CodeFormer Initialized')
else:
print('>> CodeFormer Disabled')
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_dir, gfpgan_model_path):
from ldm.invoke.restoration.gfpgan import GFPGAN
return GFPGAN(gfpgan_dir, gfpgan_model_path)
def load_codeformer(self):
from ldm.invoke.restoration.codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from ldm.invoke.restoration.realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
print('>> ESRGAN Initialized')
        return esrgan
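# Usage sketch: the loaders are independent, so callers can pick only what
# they need. The default weight paths are the ones given above.
if __name__ == '__main__':
    restoration = Restoration()
    gfpgan, codeformer = restoration.load_face_restore_models()
    esrgan = restoration.load_esrgan(esrgan_bg_tile=400)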

View File

@ -0,0 +1,87 @@
import os
import torch
import numpy as np
import warnings
import sys
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
class CodeFormerRestoration():
def __init__(self,
codeformer_dir='ldm/invoke/restoration/codeformer',
codeformer_model_path='weights/codeformer.pth') -> None:
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.invoke.restoration.codeformer_arch import CodeFormer
from torchvision.transforms.functional import normalize
from PIL import Image
cf_class = CodeFormer
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True)
checkpoint = torch.load(checkpoint_path)['params_ema']
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert('RGB')
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except RuntimeError as error:
print(f'\tFailed inference for CodeFormer: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res
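# Usage sketch, assuming a CUDA device; the checkpoint is fetched by
# load_file_from_url on first use, and 'face.png' is a placeholder file name.
if __name__ == '__main__':
    from PIL import Image
    restorer = CodeFormerRestoration()
    img = Image.open('face.png')
    restored = restorer.process(img, strength=0.8, device='cuda', fidelity=0.75)
    restored.save('face_restored.png')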

View File

@ -0,0 +1,3 @@
To use codeformer face reconstruction, you will need to copy
https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth
into this directory.

View File

@ -0,0 +1,276 @@
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List
from ldm.invoke.restoration.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
    Adjust the reference features to have similar color and illumination to
    those in the degraded features.
    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
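if __name__ == '__main__':
    # Illustrative check of the statistics matching described above (shapes and
    # values are arbitrary): after AdaIN the content features take on the
    # per-channel mean and std of the style features.
    content = torch.randn(1, 8, 16, 16)
    style = torch.randn(1, 8, 16, 16) * 3 + 5
    out = adaptive_instance_normalization(content, style)
    out_mean, out_std = calc_mean_std(out)
    style_mean, style_std = calc_mean_std(style)
    assert torch.allclose(out_mean, style_mean, atol=1e-3)
    assert torch.allclose(out_std, style_std, atol=1e-3)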
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator']):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View File

@ -0,0 +1,81 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
class GFPGAN():
def __init__(
self,
gfpgan_dir='src/gfpgan',
gfpgan_model_path='experiments/pretrained_models/GFPGANv1.4.pth') -> None:
self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
return None
sys.path.append(os.path.abspath(gfpgan_dir))
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch='clean',
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if self.gfpgan is None:
print(
f'>> WARNING: GFPGAN not initialized.'
)
print(
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
)
image = image.convert('RGB')
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res
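# Usage sketch, assuming the GFPGAN weights noted above have been installed;
# 'portrait.png' is a placeholder file name.
if __name__ == '__main__':
    gfpgan = GFPGAN()
    if gfpgan.gfpgan_model_exists:
        face = Image.open('portrait.png')
        restored = gfpgan.process(face, strength=0.8, seed=42)
        restored.save('portrait_restored.png')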

View File

@ -0,0 +1,107 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process (
self,
extents:dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback = None,
prefix = None
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_sampler()
def wrapped_callback(img,seed,**kwargs):
image_callback(img,orig_opt.seed,use_prefix=prefix,**kwargs)
result= self.generate.prompt2image(
orig_opt.prompt,
seed = orig_opt.seed, # uncomment to make it deterministic
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = extended_image.width,
height = extended_image.height,
init_img = extended_image,
strength = 0.90,
image_callback = wrapped_callback if image_callback else None,
seam_size = opt.seam_size or 96,
seam_blur = opt.seam_blur or 16,
seam_strength = opt.seam_strength or 0.7,
seam_steps = 20,
tile_size = 32,
color_match = True,
force_outpaint = True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents:dict,
) -> Image:
'''
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
'''
image = self.image
for direction in extents:
assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels/64) * 64
print(f'>> extending image {direction}ward by {pixels} pixels')
image = self._rotate(image,direction)
image = self._extend(image,pixels)
image = self._rotate(image,direction,reverse=True)
return image
def _rotate(self,image:Image,direction:str,reverse=False) -> Image:
'''
        Rotates image so that the area to extend is always at the top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
'''
transposes = {
'right': ['ROTATE_90','ROTATE_270'],
'bottom': ['ROTATE_180','ROTATE_180'],
'left': ['ROTATE_270','ROTATE_90']
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
extended_img.paste((0,0,0),[0,0,image.width,image.height+pixels])
extended_img.paste(image,box=(0,pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels))
extended_img.putalpha(alpha)
return extended_img
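# Usage sketch (all names are placeholders): `generate` is an existing
# ldm.generate.Generate instance, and `opt`/`orig_opt` are the current and
# original command namespaces.
#
#   outcrop = Outcrop(image, generate)
#   results = outcrop.process({'top': 64, 'right': 128}, opt, orig_opt)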

View File

@ -0,0 +1,92 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback = None, prefix = None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img,seed,**kwargs):
image_callback(img,seed,use_prefix=prefix,**kwargs)
return self.generate.prompt2image(
prompt,
seed = seed,
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = opt.width,
height = opt.height,
init_img = image,
strength = 0.83,
image_callback = wrapped_callback,
prefix = prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == 'left':
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom':
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels)
        assert 0 < pixels < image.height, 'Direction (-D) pixels length must be greater than 0 and less than the image height'
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height//10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == 'left':
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom':
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img

View File

@ -0,0 +1,86 @@
import torch
import warnings
import numpy as np
from PIL import Image
class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=model_path,
model=model,
tile=self.bg_tile_size,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler()
except Exception:
import traceback
import sys
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
return image
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
        # Real-ESRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler='realesrgan',
)
# Flip the channels back to RGB
res = Image.fromarray(output[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res
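# Usage sketch, assuming the realesrgan package is installed (the model
# weights are fetched from the URL above on first use); 'sample.png' is a
# placeholder file name.
if __name__ == '__main__':
    esrgan = ESRGAN(bg_tile_size=400)
    img = Image.open('sample.png')
    upscaled = esrgan.process(img, strength=1.0, seed=42, upsampler_scale=2)
    upscaled.save('sample_2x.png')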

View File

@ -0,0 +1,435 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
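if __name__ == '__main__':
    # Illustrative shape check for the quantizer above (values are arbitrary):
    # a (B, emb_dim, H, W) latent goes in, a straight-through quantized latent
    # of the same shape comes out, along with the commitment loss and stats.
    vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
    z = torch.randn(2, 256, 16, 16)
    z_q, loss, stats = vq(z)
    assert z_q.shape == z.shape
    print(f'commitment loss {loss.item():.4f}, perplexity {stats["perplexity"].item():.1f}')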
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
        # initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
return self.main(x)

30
ldm/invoke/seamless.py Normal file
View File

@ -0,0 +1,30 @@
import torch.nn as nn
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
"""
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
else:
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
if hasattr(m, 'asymmetric_padding_mode'):
del m.asymmetric_padding_mode
if hasattr(m, 'asymmetric_padding'):
del m.asymmetric_padding
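# Minimal self-contained check (illustrative only): patch a single Conv2d for
# horizontal wrap-around, confirm the spatial size is preserved, then restore
# the stock behaviour.
if __name__ == '__main__':
    import torch
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    configure_model_padding(conv, seamless=True, seamless_axes=['x'])
    out = conv(torch.randn(1, 3, 16, 16))
    assert out.shape == (1, 8, 16, 16)
    configure_model_padding(conv, seamless=False, seamless_axes=[])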

281
ldm/invoke/server.py Normal file
View File

@ -0,0 +1,281 @@
import argparse
import json
import copy
import base64
import mimetypes
import os
from ldm.invoke.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.invoke.pngwriter import PngWriter
from threading import Event
def build_opt(post_data, seed, gfpgan_model_exists):
opt = Args()
opt.parse_args() # initialize defaults
setattr(opt, 'prompt', post_data['prompt'])
setattr(opt, 'init_img', post_data['initimg'])
setattr(opt, 'strength', float(post_data['strength']))
setattr(opt, 'iterations', int(post_data['iterations']))
setattr(opt, 'steps', int(post_data['steps']))
setattr(opt, 'width', int(post_data['width']))
setattr(opt, 'height', int(post_data['height']))
setattr(opt, 'seamless', 'seamless' in post_data)
setattr(opt, 'fit', 'fit' in post_data)
setattr(opt, 'mask', 'mask' in post_data)
setattr(opt, 'invert_mask', 'invert_mask' in post_data)
setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
setattr(opt, 'sampler_name', post_data['sampler_name'])
# embiggen not practical at this point because we have no way of feeding images back into img2img
# however, this code is here against that eventuality
setattr(opt, 'embiggen', None)
setattr(opt, 'embiggen_tiles', None)
setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0)
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
setattr(opt, 'progress_images', 'progress_images' in post_data)
setattr(opt, 'progress_latents', 'progress_latents' in post_data)
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'threshold', float(post_data['threshold']))
setattr(opt, 'perlin', float(post_data['perlin']))
setattr(opt, 'hires_fix', 'hires_fix' in post_data)
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
setattr(opt, 'with_variations', [])
setattr(opt, 'embiggen', None)
setattr(opt, 'embiggen_tiles', None)
broken = False
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
for part in post_data['with_variations'].split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
                print(f'could not parse with_variations part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
                print(f'could not parse with_variations part "{part}"')
broken = True
break
opt.with_variations.append([seed, weight])
if broken:
raise CanceledException
if len(opt.with_variations) == 0:
opt.with_variations = None
return opt
class CanceledException(Exception):
pass
class DreamServer(BaseHTTPRequestHandler):
model = None
outdir = None
canceled = Event()
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/legacy_web/index.html", "rb") as content:
self.wfile.write(content.read())
elif self.path == "/config.js":
# unfortunately this import can't be at the top level, since that would cause a circular import
self.send_response(200)
self.send_header("Content-type", "application/javascript")
self.end_headers()
config = {
'gfpgan_model_exists': self.gfpgan_model_exists
}
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
elif self.path == "/run_log.json":
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
output = []
log_file = os.path.join(self.outdir, "legacy_web_log.txt")
if os.path.exists(log_file):
with open(log_file, "r") as log:
for line in log:
url, config = line.split(": {", maxsplit=1)
config = json.loads("{" + config)
config["url"] = url.lstrip(".")
if os.path.exists(url):
output.append(config)
self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8"))
elif self.path == "/cancel":
self.canceled.set()
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(bytes('{}', 'utf8'))
else:
path_dir = os.path.dirname(self.path)
out_dir = os.path.realpath(self.outdir.rstrip('/'))
if self.path.startswith('/static/legacy_web/'):
path = '.' + self.path
elif out_dir.replace('\\', '/').endswith(path_dir):
file = os.path.basename(self.path)
path = os.path.join(self.outdir,file)
else:
self.send_response(404)
return
mime_type = mimetypes.guess_type(path)[0]
if mime_type is not None:
self.send_response(200)
self.send_header("Content-type", mime_type)
self.end_headers()
with open(path, "rb") as content:
self.wfile.write(content.read())
else:
self.send_response(404)
def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
# unfortunately this import can't be at the top level, since that would cause a circular import
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)
self.canceled.clear()
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
config = post_data.copy() # Shallow copy
config['initimg'] = config.pop('initimg_name', '')
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
pngwriter = PngWriter(self.outdir)
prefix = pngwriter.unique_prefix()
# if upscaling is requested, then this will be called twice, once when
        # the images are first generated, and then again after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
# LS: This repeats code in dream.py
def image_done(image, seed, upscaled=False, first_seed=None):
name = f'{prefix}.{seed}.png'
iter_opt = copy.copy(opt)
if opt.variation_amount > 0:
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
formatted_prompt = opt.dream_prompt_str(seed=seed)
path = pngwriter.save_image_and_prompt_to_png(
image,
dream_prompt = formatted_prompt,
metadata = metadata_dumps(iter_opt,
seeds = [seed],
model_hash = self.model.model_hash
),
name = name,
)
if int(config['seed']) == -1:
config['seed'] = seed
# Append post_data to log, but only once!
if not upscaled:
with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log:
log.write(f"{path}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
upscaling_requested = opt.upscale or opt.facetool_strength > 0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled:
images_upscaled += 1
else:
images_generated += 1
if upscaling_requested:
action = None
if images_generated >= opt.iterations:
if images_upscaled < opt.iterations:
action = 'upscaling-started'
else:
action = 'upscaling-done'
if action:
x = images_upscaled + 1
self.wfile.write(bytes(json.dumps(
{'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
) + '\n',"utf-8"))
step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
step_index = 1
def image_progress(sample, step):
if self.canceled.is_set():
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
raise CanceledException
path = None
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
nonlocal step_index
            wants_progress_latents = opt.progress_latents
            wants_progress_image = opt.progress_images and step % 5 == 0
            if (wants_progress_image or wants_progress_latents) and step < opt.steps - 1:
image = self.model.sample_to_image(sample) if wants_progress_image \
else self.model.sample_to_lowres_estimated_image(sample)
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
step_index += 1
self.wfile.write(bytes(json.dumps(
{'event': 'step', 'step': step + 1, 'url': path}
) + '\n',"utf-8"))
try:
if opt.init_img is None:
# Run txt2img
self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = opt.init_img.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
opt1 = argparse.Namespace(**vars(opt))
opt1.init_img = "./img2img-tmp.png"
try:
# Run img2img
self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
finally:
# Remove the temp file
os.remove("./img2img-tmp.png")
except CanceledException:
print(f"Canceled.")
return
except Exception as e:
print("Error happened")
print(e)
self.wfile.write(bytes(json.dumps(
{'event': 'error',
'message': str(e),
'type': e.__class__.__name__}
) + '\n',"utf-8"))
raise e
class ThreadingDreamServer(ThreadingHTTPServer):
def __init__(self, server_address):
super(ThreadingDreamServer, self).__init__(server_address, DreamServer)
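# Launch sketch: the DreamServer handler expects its `model`, `outdir` and
# `gfpgan_model_exists` class attributes to be filled in before serving
# (the values below are placeholders; invoke.py does the real wiring).
if __name__ == '__main__':
    DreamServer.model = None  # a loaded ldm.generate.Generate instance in real use
    DreamServer.outdir = 'outputs/img-samples'
    DreamServer.gfpgan_model_exists = False
    dream_server = ThreadingDreamServer(('127.0.0.1', 9090))
    try:
        dream_server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        dream_server.server_close()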

View File

@ -4,7 +4,7 @@ import base64
import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.invoke.pngwriter import PngWriter, PromptFormatter
from threading import Event
def build_opt(post_data, seed, gfpgan_model_exists):
@ -125,7 +125,9 @@ class DreamServer(BaseHTTPRequestHandler):
self.end_headers()
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
# TODO temporarily commented out, import fails for some reason
# from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
gfpgan_model_exists = False
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
@ -148,7 +150,8 @@ class DreamServer(BaseHTTPRequestHandler):
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
def image_done(image, seed, upscaled=False, first_seed=-1, use_prefix=None):
print(f'First seed: {first_seed}')
name = f'{prefix}.{seed}.png'
iter_opt = argparse.Namespace(**vars(opt)) # copy
if opt.variation_amount > 0:

130
ldm/invoke/txt2mask.py Normal file
View File

@ -0,0 +1,130 @@
'''Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.
Here is typical usage:
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from PIL import Image
txt2mask = Txt2Mask(self.device)
segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
# this will return a grayscale Image of the segmented data
grayscale = segmented.to_grayscale()
# this will return a semi-transparent image in which the
# selected object(s) are opaque and the rest is at various
# levels of transparency
transparent = segmented.to_transparent()
# this will return a masked image suitable for use in inpainting:
mask = segmented.to_mask(threshold=0.5)
The threshold used in the call to to_mask() selects pixels for use in
the mask that exceed the indicated confidence threshold. Values range
from 0.0 to 1.0. The higher the threshold, the more confident the
algorithm is. In limited testing, I have found that values around 0.5
work fine.
'''
import torch
import numpy as np
from clipseg_models.clipseg import CLIPDensePredT
from einops import rearrange, repeat
from PIL import Image, ImageOps
from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
def __init__(self, image:Image, heatmap:torch.Tensor):
self.heatmap = heatmap
self.image = image
def to_grayscale(self,invert:bool=False)->Image:
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
def to_mask(self,threshold:float=0.5)->Image:
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
def to_transparent(self,invert:bool=False)->Image:
transparent_image = self.image.copy()
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite. Thus invert.
gs = self.to_grayscale(not invert)
transparent_image.putalpha(gs)
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap:Image)->Image:
size = self.image.width if (self.image.width > self.image.height) else self.image.height
resized_image = heatmap.resize(
(size,size),
resample=Image.Resampling.LANCZOS
)
return resized_image.crop((0,0,self.image.width,self.image.height))
class Txt2Mask(object):
'''
Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'.
'''
def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference')
self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
self.model.eval()
# initially we keep everything in cpu to conserve space
self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale:
'''
Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
'''
self._to_device(self.device)
prompts = [prompt] # right now we operate on just a single prompt at a time
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
])
if type(image) is str:
image = Image.open(image).convert('RGB')
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0)
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
heatmap = torch.sigmoid(preds[0][0]).cpu()
self._to_device('cpu')
return SegmentedGrayscale(image, heatmap)
def _to_device(self, device):
self.model.to(device)
def _scale_and_crop(self, image:Image)->Image:
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
if image.width > image.height: # width is constraint
scale = CLIPSEG_SIZE / image.width
else:
scale = CLIPSEG_SIZE / image.height
scaled_image.paste(
image.resize(
(int(scale * image.width),
int(scale * image.height)
),
resample=Image.Resampling.LANCZOS
),box=(0,0)
)
return scaled_image
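As a worked sketch of the geometry implemented by _scale_and_crop() and _rescale() above (illustrative only, with an assumed 640x480 input): the image is scaled so its longer side becomes CLIPSEG_SIZE, the heatmap comes back as a CLIPSEG_SIZE x CLIPSEG_SIZE square, and is then resized to a square of the longer original side and cropped back to the original box.
CLIPSEG_SIZE = 352

def preview_geometry(width: int, height: int):
    # mirrors _scale_and_crop(): scale by the longer side, paste into a CLIPSEG_SIZE square
    scale = CLIPSEG_SIZE / max(width, height)
    pasted = (int(scale * width), int(scale * height))
    # mirrors _rescale(): resize the square heatmap to the longer original side, then crop
    square = max(width, height)
    print(f'{width}x{height} -> pasted at {pasted} inside {CLIPSEG_SIZE}x{CLIPSEG_SIZE}; '
          f'heatmap resized to {square}x{square}, cropped back to {width}x{height}')

preview_geometry(640, 480)   # 640x480 -> pasted at (352, 264); heatmap -> 640x640, cropped to 640x480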

View File

@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

View File

@ -0,0 +1,201 @@
from enum import Enum
import torch
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class Context:
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.arguments = arguments
self.step_count = step_count
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_cross_attention_control(cls, model,
cross_attention_control_args: Arguments
):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
:return: None
"""
# adapted from init_attention_edit
device = cross_attention_control_args.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
cls.inject_attention_function(model)
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
class CrossAttentionType(Enum):
SELF = 1
TOKENS = 2
@classmethod
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
opts = context.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(cls.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(cls.CrossAttentionType.TOKENS)
return to_control
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model, clear_attn_slice=True):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
if clear_attn_slice:
m.last_attn_slice = None
@classmethod
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
m.use_last_attn_slice = True
@classmethod
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
if self.use_last_attn_slice:
if dim is None:
last_attn_slice = self.last_attn_slice
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
else:
last_attn_slice = self.last_attn_slice[offset]
if self.last_attn_slice_mask is None:
# just use everything
attn_slice = last_attn_slice
else:
last_attn_slice_mask = self.last_attn_slice_mask
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
this_attn_slice = attn_slice
this_attn_slice_mask = 1 - last_attn_slice_mask
attn_slice = this_attn_slice * this_attn_slice_mask + \
remapped_last_attn_slice * last_attn_slice_mask
if self.save_last_attn_slice:
if dim is None:
self.last_attn_slice = attn_slice
else:
if self.last_attn_slice is None:
self.last_attn_slice = { offset: attn_slice }
else:
self.last_attn_slice[offset] = attn_slice
return attn_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.last_attn_slice = None
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler)
@classmethod
def remove_attention_function(cls, unet):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
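The Arguments class above expects difflib.SequenceMatcher-style opcodes mapping original conditioning tokens to edited ones. A minimal sketch of building them (the token ids, the placeholder conditioning tensor and the s_/t_ option values are illustrative assumptions, not taken from the diff):
import difflib
import torch

# hypothetical token id sequences for the original and edited prompts;
# real code would obtain these from the tokenizer
original_tokens = [320, 1125, 2368, 49407]
edited_tokens   = [320, 1125, 9817, 49407]

edit_opcodes = difflib.SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
# e.g. [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]

# placeholder conditioning; normally produced by model.get_learned_conditioning(...)
edited_conditioning = torch.zeros(1, 77, 768)
# one entry per opcode; only one may be a real options dict, the rest None
edit_options = [{'s_start': 0.0, 's_end': 0.3, 't_start': 0.0, 't_end': 1.0}] + [None] * (len(edit_opcodes) - 1)

args = CrossAttentionControl.Arguments(edited_conditioning, edit_opcodes, edit_options)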

View File

@ -1,293 +1,48 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device or choose_torch_device()
super().__init__(model,schedule,model.num_timesteps,device)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(dtype=torch.float32, device=self.device)
setattr(self, name, attr)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
# This routine gets called from img2img
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
img = x_T
self.invokeai_diffuser.remove_cross_attention_control()
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
# This routine gets called from ddim_sampling() and decode()
# This is the central routine
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
def p_sample(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -295,16 +50,17 @@ class DDIMSampler(object):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(
x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index
)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
@ -351,83 +107,5 @@ class DDIMSampler(object):
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
return x_prev, pred_x0, None
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
xdec_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if img_callback:
img_callback(x_dec, i)
return x_dec
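A tiny sketch (not in the diff) checking the index/step_index bookkeeping used in p_sample() above: index counts down from step_count-1 to 0, while step_index = step_count-(index+1) counts up from 0, which is the direction do_diffusion_step() expects.
step_count = 50                            # total number of diffusion steps
for i in range(step_count):                # sampler loop variable
    index = step_count - i - 1             # counts down 49 .. 0
    step_index = step_count - (index + 1)  # counts up 0 .. 49
    assert step_index == i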

View File

@ -19,6 +19,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
import urllib
from ldm.util import (
@ -106,7 +107,7 @@ class DDPM(pl.LightningModule):
], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
f' | {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
)
self.cond_stage_model = None
self.clip_denoised = clip_denoised
@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@ -701,7 +702,7 @@ class LatentDiffusion(DDPM):
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
# only for very first batch
if (
self.scale_by_std
@ -820,21 +821,21 @@ class LatentDiffusion(DDPM):
)
return self.scale_factor * z
def get_learned_conditioning(self, c):
def get_learned_conditioning(self, c, **kwargs):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode
):
c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager
c, embedding_manager=self.embedding_manager,**kwargs
)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
c = self.cond_stage_model(c, **kwargs)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
return c
def meshgrid(self, h, w):
@ -1353,7 +1354,7 @@ class LatentDiffusion(DDPM):
num_downs = self.first_stage_model.encoder.num_resolutions - 1
rescale_latent = 2 ** (num_downs)
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
# need to rescale the tl patch coordinates to be in between (0,1)
tl_patch_coordinates = [
(
@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(
self,
@ -1890,7 +1909,7 @@ class LatentDiffusion(DDPM):
N=8,
n_row=4,
sample=True,
ddim_steps=200,
ddim_steps=50,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,
@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
finetune_keys=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for inpainting"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = (
rearrange(batch[ck], "b h w c -> b c h w")
.to(memory_format=torch.contiguous_format)
.float()
)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
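A shape-only sketch (with assumed latent dimensions, not part of the change set) of the hybrid conditioning that get_input() above assembles and that DiffusionWrapper's 'hybrid' branch consumes by channel-concatenating c_concat onto the latents:
import torch

b, h, w = 1, 64, 64
z            = torch.randn(b, 4, h, w)              # latents
mask         = torch.randn(b, 1, h, w)              # mask, interpolated to latent size
masked_image = torch.randn(b, 4, h, w)              # masked image encoded by the first stage
c_concat     = [torch.cat([mask, masked_image], dim=1)]
c_crossattn  = [torch.randn(b, 77, 768)]            # text conditioning

xc = torch.cat([z] + c_concat, dim=1)               # 'hybrid' path: concatenate along channels
print(xc.shape)                                      # torch.Size([1, 9, 64, 64])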

View File

@ -1,39 +1,152 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
from torch import nn
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
# at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule
STEP_THRESHOLD = 30
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
return result
maxval = 0.0 + torch.max(result).cpu().numpy()
minval = 0.0 + torch.min(result).cpu().numpy()
if maxval < threshold and minval > -threshold:
return result
if maxval > threshold:
maxval = min(max(1, scale*maxval), threshold)
if minval < -threshold:
minval = max(min(-1, scale*minval), -threshold)
return torch.clamp(result, min=minval, max=maxval)
class CFGDenoiser(nn.Module):
def __init__(self, model):
def __init__(self, model, threshold = 0, warmup = 0):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
else:
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(next_x, thresh)
class KSampler(object):
class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device or choose_torch_device()
denoiser = K.external.CompVisDenoiser(model)
super().__init__(
denoiser,
schedule,
steps=model.num_timesteps,
)
self.sigmas = None
self.ds = None
self.s_in = None
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
if self.karras_max is None:
self.karras_max = STEP_THRESHOLD
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
):
outer_model = self.model
self.model = outer_model.inner_model
super().make_schedule(
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
)
self.model = outer_model
self.ddim_num_steps = ddim_num_steps
# we don't need both of these sigmas, but storing them here to make
# comparison easier later on
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
self.karras_sigmas = K.sampling.get_sigmas_karras(
n=ddim_num_steps,
sigma_min=self.model.sigmas[0].item(),
sigma_max=self.model.sigmas[-1].item(),
rho=7.,
device=self.device,
)
# most of these arguments are ignored and are only present for compatibility with
if ddim_num_steps >= self.karras_max:
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
self.sigmas = self.model_sigmas
else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
z_enc,
cond,
t_enc,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
**kwargs
):
samples,_ = self.sample(
batch_size = 1,
S = t_enc,
x_T = z_enc,
shape = z_enc.shape[1:],
conditioning = cond,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback,
x0 = init_latent,
mask = mask,
**kwargs
)
return samples
# this is a no-op, provided here for compatibility with ddim and plms samplers
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
def sample(
@ -58,31 +171,128 @@ class KSampler(object):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
img_callback(k_callback_values['x'],k_callback_values['i'])
sigmas = self.model.get_sigmas(S)
# if make_schedule() hasn't been called, we do it now
if self.sigmas is None:
self.make_schedule(
ddim_num_steps=S,
ddim_eta = eta,
verbose = False,
)
# sigmas are set up in make_schedule - we take the last steps items
sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add
# more randomness to the starting image.
if x_T is not None:
x = x_T * sigmas[0]
if x0 is not None:
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
else:
x = x_T * sigmas[0]
else:
x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
sampling_result = (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
# a workaround is found.
@torch.no_grad()
def p_sample(
self,
img,
cond,
ts,
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
**kwargs,
):
if self.model_wrap is None:
self.model_wrap = CFGDenoiser(self.model)
extra_args = {
'cond': cond,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
if self.s_in is None:
self.s_in = img.new_ones([img.shape[0]])
if self.ds is None:
self.ds = []
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas holds one entry per step, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
# index starts at t_enc and works its way to zero,
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,
self.sigmas,
s_index,
s_in = self.s_in,
ds = self.ds,
extra_args=extra_args,
)
return img, None, None
# REVIEW THIS METHOD: it has never been tested. In particular,
# we should not be multiplying by self.sigmas[0] if we
# are at an intermediate step in img2img. See similar in
# sample() which does work.
def get_initial_image(self,x_T,shape,steps):
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
if x_T is not None:
return x_T + x
else:
return x
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None
self.ds = None
self.s_in = None
def q_sample(self,x0,ts):
'''
Overrides parent method to return the q_sample of the inner model.
'''
return self.model.inner_model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key
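A small worked example (illustrative only) of the cfg_apply_threshold() helper defined at the top of this file: with threshold=2.0 and the default scale=0.7, a maximum of 4.0 clamps to min(max(1, 2.8), 2.0) = 2.0 and a minimum of -3.0 clamps to max(min(-1, -2.1), -2.0) = -2.0.
import torch

x = torch.tensor([-3.0, -0.5, 0.0, 1.0, 4.0])
print(cfg_apply_threshold(x, threshold=0.0))   # threshold <= 0 disables clamping; returned unchanged
print(cfg_apply_threshold(x, threshold=2.0))   # tensor([-2.0000, -0.5000, 0.0000, 1.0000, 2.0000])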

View File

@ -4,303 +4,49 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
class PLMSSampler(object):
class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device if device else choose_torch_device()
super().__init__(model,schedule,model.num_timesteps, device)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
img = x_T
self.invokeai_diffuser.remove_cross_attention_control()
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
# this is the essential routine
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
def p_sample(
self,
x, # image, called 'img' elsewhere
c, # conditioning, called 'cond' elsewhere
t, # timesteps, called 'ts' elsewhere
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=[],
t_next=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -309,18 +55,15 @@ class PLMSSampler(object):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(

View File

@ -0,0 +1,450 @@
'''
ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
'''
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class Sampler(object):
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
self.model = model
self.ddim_timesteps = None
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
# This method was copied over from ddim.py and probably does stuff that is
# ddim-specific. Disentangle at some point.
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
):
self.total_steps = ddim_num_steps
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad()
def sample(
self,
S, # S is steps
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None:
self.make_schedule(
ddim_num_steps=S,
ddim_eta = eta,
verbose = False,
)
ts = self.get_timesteps(S)
# sampling
C, H, W = shape
shape = (batch_size, C, H, W)
samples, intermediates = self.do_sampling(
conditioning,
shape,
timesteps=ts,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
**kwargs
)
return samples, intermediates
@torch.no_grad()
def do_sampling(
self,
cond,
shape,
timesteps=None,
x_T=None,
ddim_use_original_steps=False,
callback=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
**kwargs
):
b = shape[0]
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps=steps
iterator = tqdm(
time_range,
desc=f'{self.__class__.__name__}',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all
intermediates = {'x_inter': [img], 'pred_x0': [img]}
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(b,),
step,
device=self.device,
dtype=torch.long
)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
step_count=steps
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(img,i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
# NOTE that decode() and sample() are almost the same code, and do the same thing.
# The variable names are changed in order to be confusing.
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
all_timesteps_count = None,
**kwargs
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
ts_next = torch.full(
(x_latent.shape[0],),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
outs = self.p_sample(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next = ts_next,
step_count=len(self.ddim_timesteps)
)
x_dec, pred_x0, e_t = outs
if img_callback:
img_callback(x_dec,i)
return x_dec
def get_initial_image(self,x_T,shape,timesteps=None):
if x_T is None:
return torch.randn(shape, device=self.device)
else:
return x_T
def p_sample(
self,
img,
cond,
ts,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
steps=None,
):
raise NotImplementedError("p_sample() must be implemented in a descendent class")
def prepare_to_sample(self,t_enc,**kwargs):
'''
Hook that will be called right before the very first invocation of p_sample()
to allow subclass to do additional initialization. t_enc corresponds to the actual
number of steps that will be run, and may be less than total steps if img2img is
active.
'''
pass
def get_timesteps(self,ddim_steps):
'''
The ddim and plms samplers work on timesteps. This method is called after
ddim_timesteps are created in make_schedule(), and selects the portion of
timesteps that will be used for sampling, depending on the t_enc in img2img.
'''
return self.ddim_timesteps[:ddim_steps]
def q_sample(self,x0,ts):
'''
Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
return self.model.inner_model.q_sample(x0,ts)
'''
return self.model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.model.conditioning_key
def uses_inpainting_model(self)->bool:
return self.conditioning_key() in ('hybrid','concat')
def adjust_settings(self,**kwargs):
'''
This is a catch-all method for adjusting any instance variables
after the sampler is instantiated. No type-checking performed
here, so use with care!
'''
for k in kwargs.keys():
try:
setattr(self,k,kwargs[k])
except AttributeError:
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')

View File

@ -0,0 +1,225 @@
from math import ceil
from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
class InvokeAIDiffuserComponent:
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
'''
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
self.cross_attention_control_args = cross_attention_control_args
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a callable with arguments (x, sigma, conditioning_to_apply). It will be called repeatedly; most likely it should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: Optional[int]=None
):
"""
:param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). This method may be called multiple times per step, so do not assume that step_index increases monotonically. If None, it will be estimated by comparing sigma against self.model.sigmas.
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
# to scale how much effect conditioning has, calculate the change it makes and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
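# A minimal sketch of the CFG combine step above, on toy tensors. The numbers are
# illustrative only; in practice the inputs are the unconditioned and conditioned
# model outputs for the current latents.
import torch
unconditioned_next_x = torch.tensor([0.1, 0.2])
conditioned_next_x = torch.tensor([0.3, 0.6])
unconditional_guidance_scale = 7.5
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta   # approx. tensor([1.6, 3.2])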
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = dict()
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]])
for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model)
return unconditioned_next_x, conditioned_next_x
except RuntimeError:
# make sure we clean out the attention slices we're storing on the model
# TODO don't store things on the model
CrossAttentionControl.clear_requests(self.model)
raise
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
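# A minimal sketch of the sigma-index fallback above, assuming a hypothetical
# ascending 10-entry sigma schedule (not the model's real one). With sigma = 7.0 the
# largest index whose sigma is <= 7.0 is 4, so percent_through = 1 - 5/10 = 0.5.
import torch
sigmas = torch.linspace(0.1, 14.6, 10)   # hypothetical schedule, illustrative only
sigma = torch.tensor(7.0)
smaller_sigmas = torch.nonzero(sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
percent_through = 1.0 - float(sigma_index + 1) / float(sigmas.shape[0])   # 0.5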
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c,weight in weighted_cond_list]
weights = [1] + [weight for c,weight in weighted_cond_list]
chunk_count = ceil(len(conditionings)/2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index*2
chunk_size = min(2, len(conditionings)-offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset:offset+2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale

View File

@ -1,5 +1,7 @@
from inspect import isfunction
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn, einsum
@ -90,7 +92,7 @@ class LinearAttention(nn.Module):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
@ -168,116 +171,114 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
if not torch.cuda.is_available():
mem_av = psutil.virtual_memory().available / (1024**3)
if mem_av > 32:
self.einsum_op = self.einsum_op_v1
elif mem_av > 12:
self.einsum_op = self.einsum_op_v2
else:
self.einsum_op = self.einsum_op_v3
del mem_av
else:
self.einsum_op = self.einsum_op_v4
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
# mps 64-128 GB
def einsum_op_v1(self, q, k, v, r1):
if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
else:
# q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
# needs around half of that slice_size to not generate noise
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
self.attention_slice_wrangler = None
# mps 16-32 GB (can be optimized)
def einsum_op_v2(self, q, k, v, r1):
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention
suggested_attention_slice is a softmax(dim=-1) over attention_scores
dim is None if the call is not sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is >= 0, offset and slice_size specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
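# A minimal, hypothetical wrangler matching the signature documented above. This one
# is a pass-through that simply returns the suggested slice; a real wrangler would
# typically save or modify it. The function name and usage line are assumptions.
import torch
from torch import nn

def passthrough_wrangler(module: nn.Module,
                         attention_scores: torch.Tensor,
                         suggested_attention_slice: torch.Tensor,
                         dim, offset, slice_size) -> torch.Tensor:
    # inspect or edit suggested_attention_slice here before returning it
    return suggested_attention_slice

# usage, assuming `attn` is a CrossAttention instance:
# attn.set_attention_slice_wrangler(passthrough_wrangler)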
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
# mps 8 GB
def einsum_op_v3(self, q, k, v, r1):
slice_size = 1
for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
del s2
return r1
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
# cuda
def einsum_op_v4(self, q, k, v, r1):
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
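# A minimal sketch of the memory arithmetic above, assuming a hypothetical q of shape
# (16, 4096, 40) in float16 (2 bytes per element) against a 32 MB budget.
q_batch, q_tokens, k_tokens, element_size = 16, 4096, 4096, 2
size_mb = q_batch * q_tokens * k_tokens * element_size // (1 << 20)   # 512 MB
max_tensor_mb = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()            # 16
# div <= q_batch, so attention would be computed in dim-0 slices of q_batch // div = 1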
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
mem_required = tensor_size * 2.5
steps = 1
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
q = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
r1 = self.einsum_op(q, k, v, r1)
del q, k, v
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.get_attention_mem_efficient(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
class BasicTransformerBlock(nn.Module):
@ -297,9 +298,9 @@ class BasicTransformerBlock(nn.Module):
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == 'mps' else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
x += self.attn1(self.norm1(x.clone()))
x += self.attn2(self.norm2(x.clone()), context=context)
x += self.ff(self.norm3(x.clone()))
return x

View File

@ -3,6 +3,7 @@ import gc
import math
import torch
import torch.nn as nn
from torch.nn.functional import silu
import numpy as np
from einops import rearrange
@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@ -53,9 +49,15 @@ class Upsample(nn.Module):
padding=1)
def forward(self, x):
cpu_m1_cond = True if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and \
x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 else False
if cpu_m1_cond:
x = x.to('cpu') # send to cpu
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if cpu_m1_cond:
x = x.to('mps') # return to mps
return x
@ -121,30 +123,25 @@ class ResnetBlock(nn.Module):
padding=0)
def forward(self, x, temb):
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
x_size = x.size()
if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0:
self.to('cpu')
x = x.to('cpu')
else:
self.to('mps')
x = x.to('mps')
h = self.norm1(x)
h = silu(h)
h = self.conv1(h)
if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(silu(temb))[:,:,None,None]
h5 = self.norm2(h4)
del h4
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
h = self.norm2(h)
h = silu(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@ -152,7 +149,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h8
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
@ -209,8 +206,7 @@ class AttnBlock(nn.Module):
h_ = torch.zeros_like(k, device=q.device)
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
if device_type == 'cuda':
if q.device.type == 'cuda':
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
@ -263,7 +259,7 @@ class AttnBlock(nn.Module):
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
print(f" | Making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
@ -382,7 +378,7 @@ class Model(nn.Module):
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = silu(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
@ -416,7 +412,7 @@ class Model(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -513,7 +509,7 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -539,7 +535,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
print(" | Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
@ -599,22 +595,16 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h1 = self.conv_in(z)
h = self.conv_in(z)
# middle
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# prepare for up sampling
device_type = 'mps' if h.device.type == 'mps' else 'cuda'
gc.collect()
if device_type == 'cuda':
if h.device.type == 'cuda':
torch.cuda.empty_cache()
# upsampling
@ -622,33 +612,19 @@ class Decoder(nn.Module):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h)
del h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
h = self.norm_out(h)
h = silu(h)
h = self.conv_out(h)
if self.tanh_out:
t = h
h = torch.tanh(t)
del t
h = torch.tanh(h)
return h
@ -683,7 +659,7 @@ class SimpleDecoder(nn.Module):
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
h = silu(h)
x = self.conv_out(h)
return x
@ -731,7 +707,7 @@ class UpsampleDecoder(nn.Module):
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@ -907,7 +883,7 @@ class FirstStagePostProcessor(nn.Module):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
z = silu(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)

View File

@ -64,7 +64,11 @@ def make_ddim_timesteps(
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
if c < 1:
c = 1
# remove 1 final step to prevent index out of bound error
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))[:-1]
elif ddim_discr_method == 'quad':
ddim_timesteps = (
(
@ -81,8 +85,7 @@ def make_ddim_timesteps(
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
# steps_out = ddim_timesteps + 1
steps_out = ddim_timesteps
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
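# A minimal sketch of the uniform-stride arithmetic above, with illustrative numbers
# (num_ddpm_timesteps=1000, num_ddim_timesteps=50).
import numpy as np
c = 1000 // 50                                               # 20
ddim_timesteps = np.asarray(list(range(0, 1000, c)))[:-1]    # 0, 20, ..., 960 (49 steps)
steps_out = ddim_timesteps + 1                               # 1, 21, ..., 961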
@ -252,12 +255,6 @@ def normalization(channels):
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)

View File

@ -82,7 +82,9 @@ class EmbeddingManager(nn.Module):
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280
# per bug report #572
#token_dim = 1280
token_dim = 768
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(
@ -167,9 +169,14 @@ class EmbeddingManager(nn.Module):
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if torch.cuda.is_available():
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
else:
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token
)
if placeholder_rows.nelement() == 0:
continue

View File

@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
from functools import partial
@ -5,7 +7,7 @@ import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.dream.devices import choose_torch_device
from ldm.invoke.devices import choose_torch_device
from ldm.modules.x_transformer import (
Encoder,
@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text, **kwargs):
return self(text, **kwargs)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
return_tokens_key = "return_tokens"
def forward(self, text: list, **kwargs):
'''
:param text: A batch of prompt strings, or a batch of lists of prompt-string fragments to which different
weights shall be applied.
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
'''
if self.fragment_weights_key not in kwargs:
# fallback to base class implementation
return super().forward(text, **kwargs)
fragment_weights = kwargs[self.fragment_weights_key]
# self.transformer doesn't like receiving "fragment_weights" as an argument
kwargs.pop(self.fragment_weights_key)
should_return_tokens = False
if self.return_tokens_key in kwargs:
should_return_tokens = kwargs.get(self.return_tokens_key, False)
# self.transformer doesn't like having extra kwargs
kwargs.pop(self.return_tokens_key)
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
# "red" is to tell SD that it should almost completely *ignore* redness).
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
# handle weights >=1
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
# this is our starting point
embeddings = base_embedding.unsqueeze(0)
per_embedding_weights = [1.0]
# now handle weights <1
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
# removed.
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
# fragment_weight = 1: we are at base_z => lerp weight 0
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
# so let's use tan(), because:
# tan is 0.0 at 0,
# 1.0 at PI/4, and
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
# todo handle negative weight?
per_embedding_weights.append(embedding_lerp_weight)
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
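# A minimal sketch of the tan() mapping described above, from a fragment weight in
# (0, 1] to a lerp weight for the "fragment removed" embedding. Values are illustrative.
import math
for fragment_weight in (1.0, 0.75, 0.5, 0.25):
    lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
    print(fragment_weight, round(lerp_weight, 3))   # 1.0->0.0, 0.75->0.414, 0.5->1.0, 0.25->2.414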
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
tokens = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
if include_start_and_end_markers:
return tokens
else:
return [x[1:-1] for x in tokens]
@classmethod
def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize: bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
return torch.sum(embeddings * reshaped_weights, dim=1)
# lerped embeddings has shape (77, 768)
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
'''
:param fragments:
:param weights: Per-fragment weights (CFG scaling). They do not need to be normalized, and they are not normalized here.
:return:
'''
# empty is meaningful
if len(fragments) == 0 and len(weights) == 0:
fragments = ['']
weights = [1]
item_encodings = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=True,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
all_tokens = []
per_token_weights = []
#print("all fragments:", fragments, weights)
for index, fragment in enumerate(item_encodings):
weight = weights[index]
#print("processing fragment", fragment, weight)
fragment_tokens = item_encodings[index]
#print("fragment", fragment, "processed to", fragment_tokens)
# trim bos and eos markers before appending
all_tokens.extend(fragment_tokens[1:-1])
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
if (len(all_tokens) + 2) > self.max_length:
excess_token_count = (len(all_tokens) + 2) - self.max_length
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_tokens = all_tokens[:self.max_length - 2]
per_token_weights = per_token_weights[:self.max_length - 2]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
# (77 = self.max_length)
pad_length = self.max_length - 1 - len(all_tokens)
all_tokens.insert(0, self.tokenizer.bos_token_id)
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
per_token_weights.insert(0, 1)
per_token_weights.extend([1] * pad_length)
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
return all_tokens_tensor, per_token_weights_tensor
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
'''
Build a tensor representing the passed-in tokens, each of which has a weight.
:param tokens: A tensor of shape (77) containing token ids (integers)
:param per_token_weights: A tensor of shape (77) containing weights (floats)
:param weight_delta_from_empty: If True, weight each token's distance from an "empty" feature vector rather than multiplying the whole feature vector
:param kwargs: passed on to self.transformer()
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
if weight_delta_from_empty:
empty_tokens = self.tokenizer([''] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)['input_ids'].to(self.device)
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
weighted_z_delta_from_empty = (weighted_z-empty_z)
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
#print("using empty-delta method, first 5 rows:")
#print(weighted_z[:5])
return weighted_z
else:
original_mean = z.mean()
z *= batch_weights_expanded
after_weighting_mean = z.mean()
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
mean_correction_factor = original_mean/after_weighting_mean
z *= mean_correction_factor
return z
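# A minimal sketch of the empty-delta weighting above, on toy numbers. A per-token
# weight w pulls the token's embedding toward the "empty prompt" embedding as w
# approaches 0, and leaves it unchanged at w = 1. Values are illustrative only.
import torch
z = torch.tensor([1.0, 2.0])          # stand-in for the prompt embedding
empty_z = torch.tensor([0.5, 0.5])    # stand-in for the empty-prompt embedding
w = torch.tensor([1.0, 0.5])          # per-token weights
weighted_z = empty_z + (z - empty_z) * w   # tensor([1.0000, 1.2500])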
class FrozenCLIPTextEmbedder(nn.Module):
"""

View File

@ -2,6 +2,7 @@ import importlib
import torch
import numpy as np
import math
from collections import abc
from einops import rearrange
from functools import partial
@ -74,7 +75,7 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params
@ -212,3 +213,25 @@ def parallel_data_prefetch(
return out
else:
return gather_res
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
rand_val = torch.rand(res[0]+1, res[1]+1)
angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)