Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: Merge branch 'development' into fix_generate.py
@@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
        ]

        # self._length = len(self.image_paths)
@@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
        ]

        # self._length = len(self.image_paths)
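For context, a minimal sketch of what the added ".DS_Store" filter above does, using a plain list in place of os.listdir() (the directory contents and path here are hypothetical):

    import os

    data_root = '/data/training-images'            # hypothetical path
    listing = ['01.png', '.DS_Store', '02.png']    # stand-in for os.listdir(data_root)

    image_paths = [
        os.path.join(data_root, file_path)
        for file_path in listing if file_path != ".DS_Store"
    ]
    # -> ['/data/training-images/01.png', '/data/training-images/02.png']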
@@ -1,96 +0,0 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.

Useful function exports:

get_uc_and_c()                  get the conditioned and unconditioned latent
split_weighted_subprompts()     split subprompts, normalize and weight them
log_tokenization()              print out colour-coded tokens and warn if truncated

'''
import re
import torch

def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
    uc = model.get_learned_conditioning([''])

    # get weighted sub-prompts
    weighted_subprompts = split_weighted_subprompts(
        prompt, skip_normalize
    )

    if len(weighted_subprompts) > 1:
        # i dont know if this is correct.. but it works
        c = torch.zeros_like(uc)
        # normalize each "sub prompt" and add it
        for subprompt, weight in weighted_subprompts:
            log_tokenization(subprompt, model, log_tokens)
            c = torch.add(
                c,
                model.get_learned_conditioning([subprompt]),
                alpha=weight,
            )
    else:   # just standard 1 prompt
        log_tokenization(prompt, model, log_tokens)
        c = model.get_learned_conditioning([prompt])
    return (uc, c)

def split_weighted_subprompts(text, skip_normalize=False)->list:
    """
    grabs all text up to the first occurrence of ':'
    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
    if ':' has no value defined, defaults to 1.0
    repeats until no text remaining
    """
    prompt_parser = re.compile("""
        (?P<prompt>         # capture group for 'prompt'
        (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
        )                   # end 'prompt'
        (?:                 # non-capture group
        :+                  # match one or more ':' characters
        (?P<weight>         # capture group for 'weight'
        -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
        )?                  # end weight capture group, make optional
        \s*                 # strip spaces after weight
        |                   # OR
        $                   # else, if no ':' then match end of line
        )                   # end non-capture group
        """, re.VERBOSE)
    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
    if skip_normalize:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / len(parsed_prompts)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False):
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)
    for i in range(0, totalTokens):
        token = tokens[i].replace('</w>', ' ')
        # alternate color
        s = (usedTokens % 6) + 1
        if i < model.cond_stage_model.max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1
        else:  # over max token length
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
    print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(
            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
        )
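As a quick illustration of the parsing and normalization in the deleted module above (a sketch assuming the split_weighted_subprompts() shown; the prompt string is made up):

    split_weighted_subprompts("a red fox:1 in deep snow:3")
    # -> [('a red fox', 0.25), ('in deep snow', 0.75)]    weights normalized to sum to 1
    split_weighted_subprompts("a red fox:1 in deep snow:3", skip_normalize=True)
    # -> [('a red fox', 1.0), ('in deep snow', 3.0)]      raw weights kept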
@@ -1,20 +0,0 @@
import torch
from torch import autocast
from contextlib import contextmanager, nullcontext

def choose_torch_device() -> str:
    '''Convenience routine for guessing which GPU device to run model on'''
    if torch.cuda.is_available():
        return 'cuda'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

def choose_autocast_device(device):
    '''Returns an autocast compatible device from a torch device'''
    device_type = device.type # this returns 'mps' on M1
    # autocast only supports cuda or cpu
    if device_type in ('cuda','cpu'):
        return device_type,autocast
    else:
        return 'cpu',nullcontext
@@ -1,4 +0,0 @@
'''
Initialization file for the ldm.dream.generator package
'''
from .base import Generator
@@ -1,72 +0,0 @@
'''
ldm.dream.generator.img2img descends from ldm.dream.generator
'''

import torch
import numpy as np
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler

class Img2Img(Generator):
    def __init__(self,model):
        super().__init__(model)
        self.init_latent = None    # by get_noise()

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,strength,step_callback=None,**kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """

        # PLMS sampler not supported yet, so ignore previous sampler
        if not isinstance(sampler,DDIMSampler):
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        device_type,scope = choose_autocast_device(self.model.device)
        with scope(device_type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            ) # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback = step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
            )
            return self.sample_to_image(samples)

        return make_image

    def get_noise(self,width,height):
        device      = self.model.device
        init_latent = self.init_latent
        assert init_latent is not None,'call to get_noise() when init_latent not set'
        if device.type == 'mps':
            return torch.randn_like(init_latent, device='cpu').to(device)
        else:
            return torch.randn_like(init_latent, device=device)
@@ -1,77 +0,0 @@
'''
ldm.dream.generator.inpaint descends from ldm.dream.generator
'''

import torch
import numpy as np
from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler

class Inpaint(Img2Img):
    def __init__(self,model):
        self.init_latent = None
        super().__init__(model)

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,mask_image,strength,
                       step_callback=None,**kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask. Return value depends on the seed at
        the time you call it. kwargs are 'init_latent' and 'strength'
        """

        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)

        # PLMS sampler not supported yet, so ignore previous sampler
        if not isinstance(sampler,DDIMSampler):
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        device_type,scope = choose_autocast_device(self.model.device)
        with scope(device_type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            ) # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )

            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback = step_callback,
                unconditional_guidance_scale = cfg_scale,
                unconditional_conditioning = uc,
                mask = mask_image,
                init_latent = self.init_latent
            )
            return self.sample_to_image(samples)

        return make_image
@@ -1,127 +0,0 @@
"""
Readline helper functions for dream.py (linux and mac only).
"""
import os
import re
import atexit

# ---------------readline utilities---------------------
try:
    import readline
    readline_available = True
except:
    readline_available = False

class Completer:
    def __init__(self, options):
        self.options = sorted(options)
        return

    def complete(self, text, state):
        buffer = readline.get_line_buffer()

        if text.startswith(('-I', '--init_img','-M','--init_mask')):
            return self._path_completions(text, state, ('.png','.jpg','.jpeg'))

        if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
            return self._path_completions(text, state, ())

        response = None
        if state == 0:
            # This is the first time for this text, so build a match list.
            if text:
                self.matches = [
                    s for s in self.options if s and s.startswith(text)
                ]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            response = self.matches[state]
        except IndexError:
            response = None
        return response

    def _path_completions(self, text, state, extensions):
        # get the path so far
        # TODO: replace this mess with a regular expression match
        if text.startswith('-I'):
            path = text.replace('-I', '', 1).lstrip()
        elif text.startswith('--init_img='):
            path = text.replace('--init_img=', '', 1).lstrip()
        elif text.startswith('--init_mask='):
            path = text.replace('--init_mask=', '', 1).lstrip()
        elif text.startswith('-M'):
            path = text.replace('-M', '', 1).lstrip()
        else:
            path = text

        matches = list()

        path = os.path.expanduser(path)
        if len(path) == 0:
            matches.append(text + './')
        else:
            dir = os.path.dirname(path)
            dir_list = os.listdir(dir)
            for n in dir_list:
                if n.startswith('.') and len(n) > 1:
                    continue
                full_path = os.path.join(dir, n)
                if full_path.startswith(path):
                    if os.path.isdir(full_path):
                        matches.append(
                            os.path.join(os.path.dirname(text), n) + '/'
                        )
                    elif n.endswith(extensions):
                        matches.append(os.path.join(os.path.dirname(text), n))

        try:
            response = matches[state]
        except IndexError:
            response = None
        return response

if readline_available:
    readline.set_completer(
        Completer(
            [
                '--steps','-s',
                '--seed','-S',
                '--iterations','-n',
                '--width','-W','--height','-H',
                '--cfg_scale','-C',
                '--grid','-g',
                '--individual','-i',
                '--init_img','-I',
                '--init_mask','-M',
                '--strength','-f',
                '--variants','-v',
                '--outdir','-o',
                '--sampler','-A','-m',
                '--embedding_path',
                '--device',
                '--grid','-g',
                '--gfpgan_strength','-G',
                '--upscale','-U',
                '-save_orig','--save_original',
                '--skip_normalize','-x',
                '--log_tokenization','t',
            ]
        ).complete
    )
    readline.set_completer_delims(' ')
    readline.parse_and_bind('tab: complete')

    histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
    try:
        readline.read_history_file(histfile)
        readline.set_history_length(1000)
    except FileNotFoundError:
        pass
    atexit.register(readline.write_history_file, histfile)
ldm/generate.py (1069 lines): file diff suppressed because it is too large.
@@ -1,167 +0,0 @@
import torch
import warnings
import os
import sys
import numpy as np

from PIL import Image
from scripts.dream import create_argv_parser

arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)

def run_gfpgan(image, strength, seed, upsampler_scale=4):
    print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
    gfpgan = None
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            if not gfpgan_model_exists:
                raise Exception('GFPGAN model not found at path ' + model_path)

            sys.path.append(os.path.abspath(opt.gfpgan_dir))
            from gfpgan import GFPGANer

            bg_upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )

            gfpgan = GFPGANer(
                model_path=model_path,
                upscale=upsampler_scale,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=bg_upsampler,
            )
        except Exception:
            import traceback

            print('>> Error loading GFPGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    if gfpgan is None:
        print(
            f'>> WARNING: GFPGAN not initialized.'
        )
        print(
            f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
        )
        return image

    image = image.convert('RGB')

    cropped_faces, restored_faces, restored_img = gfpgan.enhance(
        np.array(image, dtype=np.uint8),
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    res = Image.fromarray(restored_img)

    if strength < 1.0:
        # Resize the image to the new image if the sizes have changed
        if restored_img.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gfpgan = None

    return res

def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
    if bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU or MPS on M1
            use_half_precision = False
        else:
            use_half_precision = True

        model_path = {
            2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
            4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        }

        if upsampler_scale not in model_path:
            return None

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        if upsampler_scale == 4:
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )
        if upsampler_scale == 2:
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=2,
            )

        bg_upsampler = RealESRGANer(
            scale=upsampler_scale,
            model_path=model_path[upsampler_scale],
            model=model,
            tile=bg_tile,
            tile_pad=10,
            pre_pad=0,
            half=use_half_precision,
        )
    else:
        bg_upsampler = None

    return bg_upsampler

def real_esrgan_upscale(image, strength, upsampler_scale, seed):
    print(
        f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
    )

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )
        except Exception:
            import traceback

            print('>> Error loading Real-ESRGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    output, img_mode = upsampler.enhance(
        np.array(image, dtype=np.uint8),
        outscale=upsampler_scale,
        alpha_upsampler=opt.gfpgan_bg_upsampler,
    )

    res = Image.fromarray(output)

    if strength < 1.0:
        # Resize the image to the new image if the sizes have changed
        if output.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    upsampler = None

    return res
ldm/invoke/args.py (new file, 1105 lines): file diff suppressed because it is too large.
ldm/invoke/conditioning.py (new file, 195 lines)
@@ -0,0 +1,195 @@
'''
This module handles the generation of the conditioning tensors.

Useful function exports:

get_uc_and_c_and_ec()   get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control

'''
import re
from difflib import SequenceMatcher
from typing import Union

import torch

from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder


def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):

    # Extract Unconditioned Words From Prompt
    unconditioned_words = ''
    unconditional_regex = r'\[(.*?)\]'
    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)

    if len(unconditionals) > 0:
        unconditioned_words = ' '.join(unconditionals)

        # Remove Unconditioned Words From Prompt
        unconditional_regex_compile = re.compile(unconditional_regex)
        clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
    else:
        prompt_string_cleaned = prompt_string_uncleaned

    pp = PromptParser()

    parsed_prompt: Union[FlattenedPrompt, Blend] = None
    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
    if legacy_blend is not None:
        parsed_prompt = legacy_blend
    else:
        # we don't support conjunctions for now
        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]

    parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
    print(f">> Parsed prompt to {parsed_prompt}")

    conditioning = None
    cac_args: CrossAttentionControl.Arguments = None

    if type(parsed_prompt) is Blend:
        blend: Blend = parsed_prompt
        embeddings_to_blend = None
        for i, flattened_prompt in enumerate(blend.prompts):
            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
                flattened_prompt,
                log_tokens=log_tokens,
                log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})")
            embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                (embeddings_to_blend, this_embedding))
        conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
            blend.weights,
            normalize=blend.normalize_weights)
    else:
        flattened_prompt: FlattenedPrompt = parsed_prompt
        wants_cross_attention_control = type(flattened_prompt) is not Blend \
            and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
        if wants_cross_attention_control:
            original_prompt = FlattenedPrompt()
            edited_prompt = FlattenedPrompt()
            # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
            original_token_count = 0
            edited_token_count = 0
            edit_opcodes = []
            edit_options = []
            for fragment in flattened_prompt.children:
                if type(fragment) is CrossAttentionControlSubstitute:
                    original_prompt.append(fragment.original)
                    edited_prompt.append(fragment.edited)

                    to_replace_token_count = get_tokens_length(model, fragment.original)
                    replacement_token_count = get_tokens_length(model, fragment.edited)
                    edit_opcodes.append(('replace',
                                         original_token_count, original_token_count + to_replace_token_count,
                                         edited_token_count, edited_token_count + replacement_token_count
                                         ))
                    original_token_count += to_replace_token_count
                    edited_token_count += replacement_token_count
                    edit_options.append(fragment.options)
                #elif type(fragment) is CrossAttentionControlAppend:
                #    edited_prompt.append(fragment.fragment)
                else:
                    # regular fragment
                    original_prompt.append(fragment)
                    edited_prompt.append(fragment)

                    count = get_tokens_length(model, [fragment])
                    edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
                    edit_options.append(None)
                    original_token_count += count
                    edited_token_count += count
            original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
                original_prompt,
                log_tokens=log_tokens,
                log_display_label="(.swap originals)")
            # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
            # subsequent tokens when there is >1 edit and earlier edits change the total token count.
            # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
            # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
            # token 'smiling' in the inactive 'cat' edit.
            # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
            edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
                edited_prompt,
                log_tokens=log_tokens,
                log_display_label="(.swap replacements)")

            conditioning = original_embeddings
            edited_conditioning = edited_embeddings
            #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
            cac_args = CrossAttentionControl.Arguments(
                edited_conditioning = edited_conditioning,
                edit_opcodes = edit_opcodes,
                edit_options = edit_options
            )
        else:
            conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
                flattened_prompt,
                log_tokens=log_tokens,
                log_display_label="(prompt)")

    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
        parsed_negative_prompt,
        log_tokens=log_tokens,
        log_display_label="(unconditioning)")
    if isinstance(conditioning, dict):
        # hybrid conditioning is in play
        unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
        if cac_args is not None:
            print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
            cac_args = None

    return (
        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
            cross_attention_control_args=cac_args
        )
    )


def build_token_edit_opcodes(original_tokens, edited_tokens):
    original_tokens = original_tokens.cpu().numpy()[0]
    edited_tokens = edited_tokens.cpu().numpy()[0]

    return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
    if type(flattened_prompt) is not FlattenedPrompt:
        raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
    fragments = [x.text for x in flattened_prompt.children]
    weights = [x.weight for x in flattened_prompt.children]
    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
    if log_tokens:
        text = " ".join(fragments)
        log_tokenization(text, model, display_label=log_display_label)

    return embeddings, tokens

def get_tokens_length(model, fragments: list[Fragment]):
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])

def flatten_hybrid_conditioning(uncond, cond):
    '''
    This handles the choice between a conditional conditioning
    that is a tensor (used by cross attention) vs one that has additional
    dimensions as well, as used by 'hybrid'
    '''
    assert isinstance(uncond, dict)
    assert isinstance(cond, dict)
    cond_flattened = dict()
    for k in cond:
        if isinstance(cond[k], list):
            cond_flattened[k] = [
                torch.cat([uncond[k][i], cond[k][i]])
                for i in range(len(cond[k]))
            ]
        else:
            cond_flattened[k] = torch.cat([uncond[k], cond[k]])
    return uncond, cond_flattened
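A small standalone sketch of the "[...]" negative-prompt extraction step used at the top of get_uc_and_c_and_ec() above; the prompt string is made up and the trailing strip() is added here for readability:

    import re

    prompt = "a mountain lake at dawn [blurry, low quality]"   # hypothetical prompt
    unconditional_regex = r'\[(.*?)\]'

    unconditioned_words = ' '.join(re.findall(unconditional_regex, prompt))
    cleaned = re.sub(' +', ' ', re.sub(unconditional_regex, ' ', prompt)).strip()

    # unconditioned_words -> 'blurry, low quality'
    # cleaned             -> 'a mountain lake at dawn'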
ldm/invoke/devices.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch
from torch import autocast
from contextlib import nullcontext

def choose_torch_device() -> str:
    '''Convenience routine for guessing which GPU device to run model on'''
    if torch.cuda.is_available():
        return 'cuda'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

def choose_precision(device) -> str:
    '''Returns an appropriate precision for the given torch device'''
    if device.type == 'cuda':
        device_name = torch.cuda.get_device_name(device)
        if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
            return 'float16'
    return 'float32'

def choose_autocast(precision):
    '''Returns an autocast context or nullcontext for the given precision string'''
    # float16 currently requires autocast to avoid errors like:
    # 'expected scalar type Half but found Float'
    if precision == 'autocast' or precision == 'float16':
        return autocast
    return nullcontext
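A brief usage sketch of the new helpers above, mirroring how the generators in this commit call them (assumes only the functions shown and a working torch install):

    import torch

    device = torch.device(choose_torch_device())   # 'cuda', 'mps' or 'cpu'
    precision = choose_precision(device)           # 'float16' or 'float32'
    scope = choose_autocast(precision)             # torch.autocast or nullcontext

    with scope(device.type):
        x = torch.ones(1, device=device)           # ops here run under autocast when float16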
ldm/invoke/generator/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
'''
Initialization file for the ldm.invoke.generator package
'''
from .base import Generator
@@ -1,26 +1,36 @@
'''
Base class for ldm.dream.generator.*
Base class for ldm.invoke.generator.*
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
import os
import traceback
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d

downsampling = 8
CAUTION_IMG = 'assets/caution.png'

class Generator():
    def __init__(self,model):
        self.model = model
        self.seed = None
        self.latent_channels = model.channels
    def __init__(self, model, precision):
        self.model = model
        self.precision = precision
        self.seed = None
        self.latent_channels = model.channels
        self.downsampling_factor = downsampling   # BUG: should come from model or config
        self.variation_amount = 0
        self.with_variations = []
        self.safety_checker = None
        self.perlin = 0.0
        self.threshold = 0
        self.variation_amount = 0
        self.with_variations = []
        self.use_mps_noise = False
        self.free_gpu_mem = None

    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
    def get_make_image(self,prompt,**kwargs):
@@ -35,23 +45,31 @@ class Generator():
        self.variation_amount = variation_amount
        self.with_variations = with_variations

    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
                 image_callback=None, step_callback=None,
    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 safety_checker:dict=None,
                 **kwargs):
        device_type,scope = choose_autocast_device(self.model.device)
        make_image = self.get_make_image(
        scope = choose_autocast(self.precision)
        self.safety_checker = safety_checker
        make_image = self.get_make_image(
            prompt,
            sampler = sampler,
            init_image = init_image,
            width = width,
            height = height,
            step_callback = step_callback,
            threshold = threshold,
            perlin = perlin,
            **kwargs
        )

        results = []
        seed = seed if seed else self.new_seed()
        seed = seed if seed is not None else self.new_seed()
        first_seed = seed
        seed, initial_noise = self.generate_initial_noise(seed, width, height)
        with scope(device_type), self.model.ema_scope():

        # There used to be an additional self.model.ema_scope() here, but it breaks
        # the inpaint-1.5 model. Not sure what it did.... ?
        with scope(self.model.device.type):
            for n in trange(iterations, desc='Generating'):
                x_T = None
                if self.variation_amount > 0:
@@ -63,21 +81,30 @@ class Generator():
                    x_T = initial_noise
                else:
                    seed_everything(seed)
                    if self.model.device.type == 'mps':
                    try:
                        x_T = self.get_noise(width,height)
                    except:
                        print('** An error occurred while getting initial noise **')
                        print(traceback.format_exc())

                # make_image will do the equivalent of get_noise itself
                image = make_image(x_T)

                if self.safety_checker is not None:
                    image = self.safety_check(image)

                results.append([image, seed])

                if image_callback is not None:
                    image_callback(image, seed)
                    image_callback(image, seed, first_seed=first_seed)

                seed = self.new_seed()

        return results

    def sample_to_image(self,samples):
    def sample_to_image(self,samples)->Image.Image:
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        Given samples returned from a sampler, converts
        it into a PIL Image
        """
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -89,6 +116,29 @@ class Generator():
        )
        return Image.fromarray(x_sample.astype(np.uint8))

    # write an approximate RGB image from latent samples for a single step to PNG

    def sample_to_lowres_estimated_image(self,samples):
        # originally adapted from code by @erucipe and @keturn here:
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

        # these updated numbers for v1.5 are from @torridgristle
        v1_5_latent_rgb_factors = torch.tensor([
            #    R        G        B
            [ 0.3444,  0.1385,  0.0670], # L1
            [ 0.1247,  0.4027,  0.1494], # L2
            [-0.3192,  0.2513,  0.2103], # L3
            [-0.1307, -0.1874, -0.7445]  # L4
        ], dtype=samples.dtype, device=samples.device)

        latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
        latents_ubyte = (((latent_image + 1) / 2)
                         .clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)    # to 0..255
                         .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())

    def generate_initial_noise(self, seed, width, height):
        initial_noise = None
        if self.variation_amount > 0 or len(self.with_variations) > 0:
@@ -115,6 +165,10 @@ class Generator():
        """
        raise NotImplementedError("get_noise() must be implemented in a descendent class")

    def get_perlin_noise(self,width,height):
        fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)

    def new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed
@@ -156,3 +210,48 @@ class Generator():

        return v2

    def safety_check(self,image:Image.Image):
        '''
        If the CompVis safety checker flags an NSFW image, we
        blur it out.
        '''
        import diffusers

        checker = self.safety_checker['checker']
        extractor = self.safety_checker['extractor']
        features = extractor([image], return_tensors="pt")
        features.to(self.model.device)

        # unfortunately checker requires the numpy version, so we have to convert back
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)

        diffusers.logging.set_verbosity_error()
        checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
        if has_nsfw_concept[0]:
            print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
            return self.blur(image)
        else:
            return image

    def blur(self,input):
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        try:
            caution = Image.open(CAUTION_IMG)
            caution = caution.resize((caution.width // 2, caution.height //2))
            blurry.paste(caution,(0,0),caution)
        except FileNotFoundError:
            pass
        return blurry

    # this is a handy routine for debugging use. Given a generated sample,
    # convert it into a PNG image and store it at the indicated path
    def save_sample(self, sample, filepath):
        image = self.sample_to_image(sample)
        dirname = os.path.dirname(filepath) or '.'
        if not os.path.exists(dirname):
            print(f'** creating directory {dirname}')
            os.makedirs(dirname, exist_ok=True)
        image.save(filepath,'PNG')
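For context, a minimal sketch of the latent-to-RGB estimate used by sample_to_lowres_estimated_image() above, run on a random latent instead of real sampler output (the tensor shape is hypothetical):

    import torch
    from PIL import Image

    samples = torch.randn(1, 4, 64, 64)   # stand-in for sampler output (4 latent channels)

    v1_5_latent_rgb_factors = torch.tensor([
        [ 0.3444,  0.1385,  0.0670],
        [ 0.1247,  0.4027,  0.1494],
        [-0.3192,  0.2513,  0.2103],
        [-0.1307, -0.1874, -0.7445],
    ], dtype=samples.dtype)

    latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors      # (64, 64, 3)
    latents_ubyte = (((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte()).cpu()
    preview = Image.fromarray(latents_ubyte.numpy())                          # 64x64 RGB preview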
ldm/invoke/generator/embiggen.py (new file, 501 lines)
@@ -0,0 +1,501 @@
|
||||
'''
|
||||
ldm.invoke.generator.embiggen descends from ldm.invoke.generator
|
||||
and generates with ldm.invoke.generator.img2img
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from PIL import Image
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
step_callback = step_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
with scope(self.model.device.type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model,self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert('RGB')
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||
width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||
height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = Image.linear_gradient('L').rotate(
|
||||
90).resize((overlap_size_x, height))
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
#Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
|
||||
#Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(agradientC.rotate(90).resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerAA.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
||||
else:
|
||||
print(
|
||||
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||
# instead track and iterate a local seed we pass to Img2Img
|
||||
seed = self.seed
|
||||
seedintlimit = np.iinfo(np.uint32).max - 1 # only retreive this one from numpy
|
||||
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and replace
|
||||
if embiggen_tiles and not tile in embiggen_tiles:
|
||||
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
||||
else:
|
||||
print(
|
||||
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
||||
|
||||
# create a torch tensor from an Image
|
||||
                newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
                newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
                newinitimage = torch.from_numpy(newinitimage)
                newinitimage = 2.0 * newinitimage - 1.0
                newinitimage = newinitimage.to(self.model.device)

                tile_results = gen_img2img.generate(
                    prompt,
                    iterations = 1,
                    seed = seed,
                    sampler = DDIMSampler(self.model, device=self.model.device),
                    steps = steps,
                    cfg_scale = cfg_scale,
                    conditioning = conditioning,
                    ddim_eta = ddim_eta,
                    image_callback = None,          # called only after the final image is generated
                    step_callback = step_callback,  # called after each intermediate image is generated
                    width = width,
                    height = height,
                    init_image = newinitimage,      # notice that init_image is different from init_img
                    mask_image = None,
                    strength = strength,
                )

                emb_tile_store.append(tile_results[0][0])
                # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
                # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
                del newinitimage

            # Sanity check we have them all
            if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
                outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
                if embiggen_tiles:
                    outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0))
                for tile in range(emb_tiles_x * emb_tiles_y):
                    if embiggen_tiles:
                        if tile in embiggen_tiles:
                            intileimage = emb_tile_store.pop(0)
                        else:
                            continue
                    else:
                        intileimage = emb_tile_store[tile]
                    intileimage = intileimage.convert('RGBA')
                    # Get row and column entries
                    emb_row_i = tile // emb_tiles_x
                    emb_column_i = tile % emb_tiles_x
                    if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
                        left = 0
                        top = 0
                    else:
                        # Determine upper-left point
                        if emb_column_i + 1 == emb_tiles_x:
                            left = initsuperwidth - width
                        else:
                            left = round(emb_column_i * (width - overlap_size_x))
                        if emb_row_i + 1 == emb_tiles_y:
                            top = initsuperheight - height
                        else:
                            top = round(emb_row_i * (height - overlap_size_y))
                    # Handle gradients for various conditions
                    # Handle emb_rerun case
                    if embiggen_tiles:
                        # top of image
                        if emb_row_i == 0:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) not in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerB)
                                    # Otherwise do nothing on this tile
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerR)
                                else:
                                    intileimage.putalpha(alphaLayerRBC)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerL)
                                else:
                                    intileimage.putalpha(alphaLayerLBC)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerL)
                                    else:
                                        intileimage.putalpha(alphaLayerLBC)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerLR)
                                else:
                                    intileimage.putalpha(alphaLayerABT)
                        # bottom of image
                        elif emb_row_i == emb_tiles_y - 1:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerTaC)
                                else:
                                    intileimage.putalpha(alphaLayerRTC)
                            elif emb_column_i == emb_tiles_x - 1:
                                # No tiles to look ahead to
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerLTaC)
                                else:
                                    intileimage.putalpha(alphaLayerABB)
                        # vertical middle of image
                        else:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerTB)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerRTC)
                                else:
                                    intileimage.putalpha(alphaLayerABL)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerLTC)
                                else:
                                    intileimage.putalpha(alphaLayerABR)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerLTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerABR)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerABB)
                                else:
                                    intileimage.putalpha(alphaLayerAA)
                    # Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
                    else:
                        if emb_row_i == 0 and emb_column_i >= 1:
                            intileimage.putalpha(alphaLayerL)
                        elif emb_row_i >= 1 and emb_column_i == 0:
                            if emb_column_i + 1 == emb_tiles_x:  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerT)
                            else:
                                intileimage.putalpha(alphaLayerTaC)
                        else:
                            if emb_column_i + 1 == emb_tiles_x:  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                intileimage.putalpha(alphaLayerLTaC)
                    # Layer tile onto final image
                    outputsuperimage.alpha_composite(intileimage, (left, top))
            else:
                print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')

            # after internal loops and patching up return Embiggen image
            return outputsuperimage
        # end of function declaration
        return make_image
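The tile placement arithmetic above (emb_row_i = tile // emb_tiles_x, emb_column_i = tile % emb_tiles_x, plus the overlap-aware left/top offsets) is easiest to sanity-check in isolation. The helper below is an illustrative sketch only, not part of the module; the names simply mirror the variables used in make_image().

# Illustrative sketch of the Embiggen tile placement math (not part of the module).
def tile_origin(tile, emb_tiles_x, emb_tiles_y, width, height,
                initsuperwidth, initsuperheight, overlap_size_x, overlap_size_y):
    row, col = divmod(tile, emb_tiles_x)       # same as tile // emb_tiles_x, tile % emb_tiles_x
    if col + 1 == emb_tiles_x:                 # last column hugs the right edge
        left = initsuperwidth - width
    else:
        left = round(col * (width - overlap_size_x))
    if row + 1 == emb_tiles_y:                 # last row hugs the bottom edge
        top = initsuperheight - height
    else:
        top = round(row * (height - overlap_size_y))
    return left, top

# e.g. a 3x2 grid of 512x512 tiles with 64px overlap on a 1408x960 canvas:
# tile_origin(4, 3, 2, 512, 512, 1408, 960, 64, 64) -> (448, 448)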
90  ldm/invoke/generator/img2img.py  Normal file
@ -0,0 +1,90 @@
'''
ldm.invoke.generator.img2img descends from ldm.invoke.generator
'''

import torch
import numpy as np
import PIL
from torch import Tensor
from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None    # by get_noise()

    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        self.perlin = perlin

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        if isinstance(init_image, PIL.Image.Image):
            init_image = self._image_to_tensor(init_image.convert('RGB'))

        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            )  # move to latent space

        t_enc = int(strength * steps)
        uc, c, extra_conditioning_info = conditioning

        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback = step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                init_latent = self.init_latent,  # changes how noising is performed in ksampler
                extra_conditioning_info = extra_conditioning_info,
                all_timesteps_count = steps
            )

            return self.sample_to_image(samples)

        return make_image

    def get_noise(self,width,height):
        device = self.model.device
        init_latent = self.init_latent
        assert init_latent is not None,'call to get_noise() when init_latent not set'
        if device.type == 'mps':
            x = torch.randn_like(init_latent, device='cpu').to(device)
        else:
            x = torch.randn_like(init_latent, device=device)
        if self.perlin > 0.0:
            shape = init_latent.shape
            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
        return x

    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
        image = np.array(image).astype(np.float32) / 255.0
        if len(image.shape) == 2:   # 'L' image, as in a mask
            image = image[None,None]
        else:                       # 'RGB' image
            image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        if normalize:
            image = 2.0 * image - 1.0
        return image.to(self.model.device)
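For orientation, this is roughly how the generator above is driven: the caller asks for a make_image closure and then feeds it a noise tensor per sample. A minimal sketch, assuming a loaded model, a sampler, a PIL init_img, and conditioning already packed as (uc, c, extra_conditioning_info):

# Minimal sketch (assumes `model`, `sampler`, `conditioning` and a PIL `init_img` already exist).
from ldm.invoke.generator.img2img import Img2Img

generator = Img2Img(model, precision='float16')
make_image = generator.get_make_image(
    prompt='a painting of a lighthouse',
    sampler=sampler,
    steps=30,
    cfg_scale=7.5,
    ddim_eta=0.0,
    conditioning=conditioning,   # (uc, c, extra_conditioning_info)
    init_image=init_img,
    strength=0.6,
)
noise = generator.get_noise(width=512, height=512)   # shaped like the init latent
image = make_image(noise)                             # returns a PIL image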
315  ldm/invoke/generator/inpaint.py  Normal file
@ -0,0 +1,315 @@
|
||||
'''
|
||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
self.mask_blur_radius = 0
|
||||
super().__init__(model, precision)
|
||||
|
||||
# Outpaint support code
|
||||
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False
|
||||
)
|
||||
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != 'RGBA':
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:,:,:,:,3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = (tiles_mask > 0)
|
||||
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed = seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1,2)
|
||||
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||
si = Image.fromarray(st, mode='RGBA')
|
||||
|
||||
return si
|
||||
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
def seam_paint(self,
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,strength,
|
||||
noise
|
||||
) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0
|
||||
)
|
||||
|
||||
result = make_image(noise)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image
|
||||
|
||||
# Fill missing areas of original image
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(),
|
||||
seed = self.seed,
|
||||
tile_size = tile_size
|
||||
)
|
||||
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||
|
||||
# Create init tensor
|
||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // downsampling,
|
||||
mask_image.height // downsampling
|
||||
),
|
||||
resample=Image.Resampling.NEAREST
|
||||
)
|
||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
print(
|
||||
f">> Using recommended DDIM sampler for inpainting."
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
||||
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
||||
if inpaint_replace > 0.0:
|
||||
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
||||
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
||||
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
||||
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
||||
z_enc = z_enc * mask_image + masked_region
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
|
||||
result = self.sample_to_image(samples)
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
result = self.seam_paint(
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
seam_strength,
|
||||
x_T)
|
||||
|
||||
return result
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
|
||||
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(image, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
|
||||
return corrected_result
|
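The color_correct() method above is a per-channel moment match: pixels of the generated image under the mask are shifted so that their mean and standard deviation agree with those of the original image's masked pixels. A toy numeric sketch of that normalization (the real code operates on full HxWx3 arrays and then clips to 0..255):

# Toy illustration of the mean/std matching used in color_correct().
import numpy as np

gen = np.array([120.0, 130.0, 140.0])      # generated pixel values in one channel
init_mean, init_std = 100.0, 5.0           # statistics of the original masked region
gen_mean, gen_std = gen.mean(), gen.std()

matched = (gen - gen_mean) / gen_std * init_std + init_mean
print(matched)   # [93.87..., 100.0, 106.12...] -> clipped and cast to uint8 in the real code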
153  ldm/invoke/generator/omnibus.py  Normal file
@ -0,0 +1,153 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from PIL import Image, ImageOps
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
|
||||
class Omnibus(Img2Img,Txt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
init_image = None,
|
||||
mask_image = None,
|
||||
strength = None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, Image.Image):
|
||||
if init_image.mode != 'RGB':
|
||||
init_image = init_image.convert('RGB')
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
|
||||
|
||||
t_enc = steps
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
elif init_image is not None: # img2img
|
||||
scope = choose_autocast(self.precision)
|
||||
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
# create a completely black mask (1s)
|
||||
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||
# and the masked image is just a copy of the original
|
||||
masked_image = init_image
|
||||
|
||||
else: # txt2img
|
||||
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||
masked_image = init_image
|
||||
|
||||
self.init_latent = init_image
|
||||
height = init_image.shape[2]
|
||||
width = init_image.shape[3]
|
||||
model = self.model
|
||||
|
||||
def make_image(x_T):
|
||||
with torch.no_grad():
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
|
||||
batch = self.make_batch_sd(
|
||||
init_image,
|
||||
mask_image,
|
||||
masked_image,
|
||||
prompt=prompt,
|
||||
device=model.device,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != model.masked_image_key:
|
||||
bchw = [num_samples, 4, height//8, width//8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
|
||||
# cond
|
||||
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
shape = [model.channels, height//8, width//8]
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = cond,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc_full,
|
||||
eta = 1.0,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def make_batch_sd(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
masked_image,
|
||||
prompt,
|
||||
device,
|
||||
num_samples=1):
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [prompt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
|
||||
|
||||
def get_noise(self, width:int, height:int):
|
||||
if self.init_latent is not None:
|
||||
height = self.init_latent.shape[2]
|
||||
width = self.init_latent.shape[3]
|
||||
return Txt2Img.get_noise(self,width,height)
|
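As the module docstring notes, the runwayml inpainting checkpoint expects a 9-channel UNet input: the 4-channel noised latent concatenated with a 1-channel downsampled mask and the 4-channel latent of the masked image, which is what the c_cat assembly above prepares. A rough shape-bookkeeping sketch with made-up tensors (not runnable against a real checkpoint):

# Shape sketch only; the tensors are placeholders for the real latents.
import torch

num_samples, height, width = 1, 512, 512
latent     = torch.randn(num_samples, 4, height // 8, width // 8)   # noised latent (x_T / z)
mask       = torch.ones(num_samples, 1, height // 8, width // 8)    # mask interpolated to latent size
masked_img = torch.randn(num_samples, 4, height // 8, width // 8)   # encode_first_stage(masked_image)

c_cat = torch.cat([mask, masked_img], dim=1)     # 5 extra conditioning channels
unet_input = torch.cat([latent, c_cat], dim=1)   # 4 + 5 = 9 channels fed to the UNet
print(unet_input.shape)                          # torch.Size([1, 9, 64, 64])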
@ -1,24 +1,27 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img inherits from ldm.dream.generator
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self,model):
|
||||
super().__init__(model)
|
||||
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,**kwargs):
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
self.perlin = perlin
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -27,6 +30,12 @@ class Txt2Img(Generator):
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
@ -36,9 +45,15 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
@ -47,15 +62,19 @@ class Txt2Img(Generator):
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
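The perlin branch above mixes structured noise into the Gaussian sample as x = (1 - perlin) * x + perlin * perlin_noise. A tiny sketch of that blend, with the perlin generator stubbed out by random data:

# Sketch of the noise mixing used in get_noise(); the perlin source is stubbed out here.
import torch

perlin = 0.3
gaussian = torch.randn(1, 4, 64, 64)
perlin_noise = torch.randn(1, 4, 64, 64)             # stand-in for self.get_perlin_noise(w, h)
x = (1 - perlin) * gaussian + perlin * perlin_noise  # 70% Gaussian, 30% structured noise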
186  ldm/invoke/generator/txt2img2img.py  Normal file
@ -0,0 +1,186 @@
|
||||
'''
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,strength,step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
scale_dim = min(width, height)
|
||||
scale = 512 / scale_dim
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
init_height // self.downsampling_factor,
|
||||
init_width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
|
||||
image = self.sample_to_image(samples)
|
||||
image = InitImageResizer(image).resize(width, height)
|
||||
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2.0 * image - 1.0
|
||||
image = image.to(self.model.device)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
samples = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(image)
|
||||
) # move back to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
ddim_sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
z_enc = ddim_sampler.stochastic_encode(
|
||||
samples,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=self.get_noise(width,height,False)
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = ddim_sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
|
||||
omnibus = Omnibus(self.model,self.precision)
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
width=init_width,
|
||||
height=init_height,
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||
image = result[0][0]
|
||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||
print(kwargs.pop('init_image',None))
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=interpolated_image,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=result[0][1],
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
return result[0][0]
|
||||
|
||||
if sampler.uses_inpainting_model():
|
||||
return inpaint_make_image
|
||||
else:
|
||||
return make_image
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
trained_square = 512 * 512
|
||||
actual_square = width * height
|
||||
scale = math.sqrt(trained_square / actual_square)
|
||||
scaled_width = math.ceil(scale * width / 64) * 64
|
||||
scaled_height = math.ceil(scale * height / 64) * 64
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
|
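get_noise() above shrinks the requested resolution back toward the 512x512 training resolution before the first txt2img pass, rounding each side up to a multiple of 64. Worked through for a hypothetical 960x640 request:

# Worked example of the scaling in get_noise() (values only, no model needed).
import math

width, height = 960, 640
scale = math.sqrt((512 * 512) / (width * height))      # ~0.6532
scaled_width  = math.ceil(scale * width  / 64) * 64    # ceil(9.80) * 64 = 640
scaled_height = math.ceil(scale * height / 64) * 64    # ceil(6.53) * 64 = 448
print(scaled_width, scaled_height)                     # 640 448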
66  ldm/invoke/log.py  Normal file
@ -0,0 +1,66 @@
"""
Functions for better format logging
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
    1 write_log_message -- Writes a message to the console
    2 write_log_files -- Writes a message to files
        2.1 write_log_default -- File in plain text
        2.2 write_log_txt -- File in txt format
        2.3 write_log_markdown -- File in markdown format
"""

import os


def write_log(results, log_path, file_types, output_cntr):
    """
    logs the name of the output image, prompt, and prompt args to the terminal and files
    """
    output_cntr = write_log_message(results, output_cntr)
    write_log_files(results, log_path, file_types)
    return output_cntr


def write_log_message(results, output_cntr):
    """logs to the terminal"""
    if len(results) == 0:
        return output_cntr
    log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    if len(log_lines) > 1:
        subcntr = 1
        for l in log_lines:
            print(f"[{output_cntr}.{subcntr}] {l}", end="")
            subcntr += 1
    else:
        print(f"[{output_cntr}] {log_lines[0]}", end="")
    return output_cntr + 1


def write_log_files(results, log_path, file_types):
    for file_type in file_types:
        if file_type == "txt":
            write_log_txt(log_path, results)
        elif file_type == "md" or file_type == "markdown":
            write_log_markdown(log_path, results)
        else:
            print(f"'{file_type}' format is not supported, so write in plain text")
            write_log_default(log_path, results, file_type)


def write_log_default(log_path, results, file_type):
    plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
        file.writelines(plain_txt_lines)


def write_log_txt(log_path, results):
    txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    with open(log_path + ".txt", "a", encoding="utf-8") as file:
        file.writelines(txt_lines)


def write_log_markdown(log_path, results):
    md_lines = []
    for path, prompt in results:
        file_name = os.path.basename(path)
        md_lines.append(f"## {file_name}\n\n\n{prompt}\n")
    with open(log_path + ".md", "a", encoding="utf-8") as file:
        file.writelines(md_lines)
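A short usage sketch of the logging helpers above; the paths and prompt strings are invented for illustration:

# Hypothetical example: log two results to the terminal plus .txt and .md files.
from ldm.invoke.log import write_log

results = [
    ('outputs/000001.1234.png', '"a lighthouse at dusk" -s 30 -C 7.5'),
    ('outputs/000002.5678.png', '"a lighthouse at dawn" -s 30 -C 7.5'),
]
counter = write_log(results, log_path='outputs/dream_log', file_types=['txt', 'md'], output_cntr=1)
# prints "[1.1] ..." and "[1.2] ...", returns 2, and appends to dream_log.txt and dream_log.md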
370  ldm/invoke/model_cache.py  Normal file
@ -0,0 +1,370 @@
|
||||
'''
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import transformers
|
||||
import traceback
|
||||
from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
DEFAULT_MAX_MODELS=2
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
||||
'''
Initialize with the path to the models.yaml config file,
the torch device type, and the precision. The optional
max_loaded_models argument sets how many models may be kept
in the RAM cache at once; when the limit is reached the least
recently used model is purged and reloaded from disk when
next needed (default is 2).
'''
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.max_loaded_models = max_loaded_models
|
||||
self.models = {}
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model name identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
if self.current_model != model_name:
|
||||
if model_name not in self.models: # make room for a new one
|
||||
self._make_cache_room()
|
||||
self.offload_model(self.current_model)
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {}
|
||||
self.models[model_name]['model'] = requested_model
|
||||
self.models[model_name]['width'] = width
|
||||
self.models[model_name]['height'] = height
|
||||
self.models[model_name]['hash'] = hash
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(traceback.format_exc())
|
||||
print(f'** restoring {self.current_model}')
|
||||
self.get_model(self.current_model)
|
||||
return None
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return {
|
||||
'model':requested_model,
|
||||
'width':width,
|
||||
'height':height,
|
||||
'hash': hash
|
||||
}
|
||||
|
||||
def default_model(self) -> str:
|
||||
'''
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
'''
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get('default',False):
|
||||
return model_name
|
||||
return None
|
||||
|
||||
def set_default_model(self,model_name:str):
|
||||
'''
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_cache.commit()
|
||||
'''
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
for model in self.models:
|
||||
self.models[model].pop('default',None)
|
||||
self.models[model_name]['default'] = True
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
},
|
||||
model_name2: { etc }
|
||||
'''
|
||||
result = {}
|
||||
for name in self.config:
|
||||
try:
|
||||
description = self.config[name].description
|
||||
except ConfigAttributeError:
|
||||
description = '<no description>'
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
elif name in self.models:
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
result[name]={}
|
||||
result[name]['status']=status
|
||||
result[name]['description']=description
|
||||
return result
|
||||
|
||||
def print_models(self):
|
||||
'''
|
||||
Print a table of models, their descriptions, and load status
|
||||
'''
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
|
||||
if models[name]['status'] == 'active':
|
||||
print(f'\033[1m{line}\033[0m')
|
||||
else:
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name:str) ->bool:
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
omega = self.config
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
return True
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
'''
|
||||
omega = self.config
|
||||
# check that all the required fields are present
|
||||
for field in ('description','weights','height','width','config'):
|
||||
assert field in model_attributes, f'required field {field} is missing'
|
||||
|
||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||
config = omega[model_name] if model_name in omega else {}
|
||||
for field in model_attributes:
|
||||
config[field] = model_attributes[field]
|
||||
|
||||
omega[model_name] = config
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
return True
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae',None)
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae:
|
||||
if os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
else:
|
||||
print(f' | VAE file {vae} not found. Skipping.')
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
model.eval()
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
def offload_model(self, model_name:str):
|
||||
'''
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
'''
|
||||
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
message = f'>> Offloading {model_name} to CPU'
|
||||
print(message)
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _make_cache_room(self):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
def print_vram_usage(self):
|
||||
if self._has_cuda:
|
||||
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||
|
||||
def commit(self,config_file_path:str):
|
||||
'''
|
||||
Write current configuration out to the indicated file.
|
||||
'''
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(self.preamble())
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,config_file_path)
|
||||
|
||||
def preamble(self):
|
||||
'''
|
||||
Returns the preamble for the config file.
|
||||
'''
|
||||
return '''# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
'''
|
||||
|
||||
def _invalidate_cached_model(self,model_name:str):
|
||||
self.offload_model(model_name)
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
self.models.pop(model_name,None)
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.cond_stage_model.device = 'cpu'
|
||||
model.first_stage_model.to('cpu')
|
||||
model.cond_stage_model.to('cpu')
|
||||
model.model.to('cpu')
|
||||
return model.to('cpu')
|
||||
else:
|
||||
return model
|
||||
|
||||
def _model_from_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.to(self.device)
|
||||
model.first_stage_model.to(self.device)
|
||||
model.cond_stage_model.to(self.device)
|
||||
model.cond_stage_model.device = self.device
|
||||
return model
|
||||
|
||||
def _pop_oldest_model(self):
|
||||
'''
|
||||
Remove the first element of the FIFO, which ought
|
||||
to be the least recently accessed model. Do not
|
||||
pop the last one, because it is in active use!
|
||||
'''
|
||||
return self.stack.pop(0)
|
||||
|
||||
def _push_newest_model(self,model_name:str):
|
||||
'''
|
||||
Maintain a simple FIFO. First element is always the
|
||||
least recent, and last element is always the most recent.
|
||||
'''
|
||||
try:
|
||||
self.stack.remove(model_name)
|
||||
except ValueError:
|
||||
pass
|
||||
self.stack.append(model_name)
|
||||
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def _cached_sha256(self,path,data):
|
||||
dirname = os.path.dirname(path)
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
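A hedged sketch of how the cache above is typically driven; the models.yaml path and model name are assumptions, and the layout mirrors the preamble() comment and the fields checked in add_model():

# Illustrative only: assumes a models.yaml with a 'stable-diffusion-1.5' entry.
from omegaconf import OmegaConf
from ldm.invoke.model_cache import ModelCache

config = OmegaConf.load('configs/models.yaml')
cache = ModelCache(config, device_type='cuda', precision='float16', max_loaded_models=2)

cache.print_models()                           # table of name / status / description
info = cache.get_model('stable-diffusion-1.5')
if info is not None:
    model, width, height = info['model'], info['width'], info['height']
# switching to another model offloads this one to CPU RAM; a third switch
# purges the least recently used entry entirely (max_loaded_models=2).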
@ -3,12 +3,13 @@ Two helper classes for dealing with PNG images and their path names.
|
||||
PngWriter -- Converts Images generated by T2I into PNGs, finds
|
||||
appropriate names for them, and writes prompt metadata
|
||||
into the PNG.
|
||||
PromptFormatter -- Utility for converting a Namespace of prompt parameters
|
||||
back into a formatted prompt string with command-line switches.
|
||||
|
||||
Exports function retrieve_metadata(path)
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from PIL import PngImagePlugin
|
||||
import json
|
||||
from PIL import PngImagePlugin, Image
|
||||
|
||||
# -------------------image generation utils-----
|
||||
|
||||
@ -32,13 +33,39 @@ class PngWriter:
|
||||
|
||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||
# returns full path of output
|
||||
def save_image_and_prompt_to_png(self, image, prompt, name):
|
||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
|
||||
path = os.path.join(self.outdir, name)
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text('Dream', prompt)
|
||||
image.save(path, 'PNG', pnginfo=info)
|
||||
info.add_text('Dream', dream_prompt)
|
||||
if metadata:
|
||||
info.add_text('sd-metadata', json.dumps(metadata))
|
||||
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
|
||||
return path
|
||||
|
||||
def retrieve_metadata(self,img_basename):
|
||||
'''
|
||||
Given a PNG filename stored in outdir, returns the "sd-metadata"
|
||||
metadata stored there, as a dict
|
||||
'''
|
||||
path = os.path.join(self.outdir,img_basename)
|
||||
all_metadata = retrieve_metadata(path)
|
||||
return all_metadata['sd-metadata']
|
||||
|
||||
def retrieve_metadata(img_path):
|
||||
'''
|
||||
Given a path to a PNG image, returns the "sd-metadata"
|
||||
metadata stored there, as a dict
|
||||
'''
|
||||
im = Image.open(img_path)
|
||||
md = im.text.get('sd-metadata', '{}')
|
||||
dream_prompt = im.text.get('Dream', '')
|
||||
return {'sd-metadata': json.loads(md), 'Dream': dream_prompt}
|
||||
|
||||
def write_metadata(img_path:str, meta:dict):
|
||||
im = Image.open(img_path)
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text('sd-metadata', json.dumps(meta))
|
||||
im.save(img_path,'PNG',pnginfo=info)
|
||||
|
||||
class PromptFormatter:
|
||||
def __init__(self, t2i, opt):
|
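The metadata round trip added above can be exercised on any image produced by PngWriter. A sketch, assuming the module lives at ldm.invoke.pngwriter on this branch and that `image` is an existing PIL image:

# Hypothetical file name; shows the 'Dream' and 'sd-metadata' round trip described above.
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata

writer = PngWriter('outputs')
path = writer.save_image_and_prompt_to_png(
    image, '"a lighthouse" -s 30 -S 1234', '000001.1234.png',
    metadata={'model': 'stable diffusion', 'prompt': 'a lighthouse'}
)

md = retrieve_metadata(path)   # {'sd-metadata': {...}, 'Dream': '"a lighthouse" -s 30 -S 1234'}
md['sd-metadata']['edited'] = True
write_metadata(path, md['sd-metadata'])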
667  ldm/invoke/prompt_parser.py  Normal file
@ -0,0 +1,667 @@
|
||||
import string
|
||||
from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
'''
|
||||
This module parses prompt strings and produces tree-like structures that can be used to generate and control the conditioning tensors, including weighted subprompts.
|
||||
|
||||
Useful class exports:
|
||||
|
||||
PromptParser - parses prompts
|
||||
|
||||
Useful function exports:
|
||||
|
||||
split_weighted_subprompts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
'''
|
||||
|
||||
class Prompt():
|
||||
"""
|
||||
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
|
||||
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
|
||||
that are to be blended together are stored individually as Prompt objects.
|
||||
|
||||
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
|
||||
to produce a FlattenedPrompt.
|
||||
"""
|
||||
def __init__(self, parts: list):
|
||||
for c in parts:
|
||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
|
||||
self.children = parts
|
||||
def __repr__(self):
|
||||
return f"Prompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Prompt and other.children == self.children
|
||||
|
||||
class BaseFragment:
|
||||
pass
|
||||
|
||||
class FlattenedPrompt():
|
||||
"""
|
||||
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
|
||||
"""
|
||||
def __init__(self, parts: list=[]):
|
||||
self.children = []
|
||||
for part in parts:
|
||||
self.append(part)
|
||||
|
||||
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||
# verify type correctness
|
||||
if type(fragment) is list:
|
||||
for x in fragment:
|
||||
self.append(x)
|
||||
elif issubclass(type(fragment), BaseFragment):
|
||||
self.children.append(fragment)
|
||||
elif type(fragment) is tuple:
|
||||
# upgrade tuples to Fragments
|
||||
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
self.children.append(Fragment(fragment[0], fragment[1]))
|
||||
else:
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return len(self.children) == 0 or \
|
||||
(len(self.children) == 1 and len(self.children[0].text) == 0)
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is FlattenedPrompt and other.children == self.children
|
||||
|
||||
|
||||
class Fragment(BaseFragment):
|
||||
"""
|
||||
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
|
||||
"""
|
||||
def __init__(self, text: str, weight: float=1):
|
||||
assert(type(text) is str)
|
||||
if '\\"' in text or '\\(' in text or '\\)' in text:
|
||||
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
|
||||
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
|
||||
self.text = text
|
||||
self.weight = float(weight)
|
||||
|
||||
def __repr__(self):
|
||||
return "Fragment:'"+self.text+"'@"+str(self.weight)
|
||||
def __eq__(self, other):
|
||||
return type(other) is Fragment \
|
||||
and other.text == self.text \
|
||||
and other.weight == self.weight
|
||||
|
||||
class Attention():
|
||||
"""
|
||||
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
|
||||
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
|
||||
|
||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||
"""
|
||||
def __init__(self, weight: float, children: list):
|
||||
if type(weight) is not float:
|
||||
raise PromptParser.ParsingException(
|
||||
f"Attention weight must be float (got {type(weight).__name__} {weight})")
|
||||
self.weight = weight
|
||||
if type(children) is not list:
|
||||
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
|
||||
assert(type(children) is list)
|
||||
self.children = children
|
||||
#print(f"A: requested attention '{children}' to {weight}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Attention:{self.children} * {self.weight}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Attention and other.weight == self.weight and other.children == self.children
|
||||
|
||||
class CrossAttentionControlledFragment(BaseFragment):
|
||||
pass
|
||||
|
||||
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||
"""
|
||||
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
|
||||
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
|
||||
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
|
||||
the result should be an "edited" image that looks like the "original" image with concepts swapped.
|
||||
|
||||
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
|
||||
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
|
||||
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
|
||||
or it may represent a larger portion of the token sequence:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
|
||||
edited=[Fragment('a smiling dog sitting on a car')])
|
||||
|
||||
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
|
||||
FlattenedPrompt([
|
||||
Fragment('a'),
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
|
||||
Fragment('sitting on a car')
|
||||
])
|
||||
"""
|
||||
def __init__(self, original: list, edited: list, options: dict=None):
|
||||
self.original = original if len(original)>0 else [Fragment('')]
|
||||
self.edited = edited if len(edited)>0 else [Fragment('')]
|
||||
|
||||
default_options = {
|
||||
's_start': 0.0,
|
||||
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
|
||||
't_start': 0.0,
|
||||
't_end': 1.0
|
||||
}
|
||||
merged_options = default_options
|
||||
if options is not None:
|
||||
shape_freedom = options.pop('shape_freedom', None)
|
||||
if shape_freedom is not None:
|
||||
# high shape freedom = SD can do what it wants with the shape of the object
|
||||
# high shape freedom => s_end = 0
|
||||
# low shape freedom => s_end = 1
|
||||
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
|
||||
# and there is very little perceptible difference as s_end increases above 0.5
|
||||
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
|
||||
# -> cube root and subtract from 1.0
|
||||
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
|
||||
#print('converted shape_freedom argument to', merged_options)
|
||||
merged_options.update(options)
|
||||
|
||||
self.options = merged_options
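# Worked example (sketch, derived from the cube-root mapping above): s_end = 1.0 - shape_freedom ** (1/3)
#   shape_freedom = 0.0 -> s_end = 1.0    (shape locked to the original)
#   shape_freedom = 0.5 -> s_end ~= 0.206 (the default above)
#   shape_freedom = 1.0 -> s_end = 0.0    (SD is free to change the shape)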
|
||||
|
||||
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options}))"
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlSubstitute \
|
||||
and other.original == self.original \
|
||||
and other.edited == self.edited \
|
||||
and other.options == self.options
|
||||
|
||||
|
||||
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
||||
def __init__(self, fragment: Fragment):
|
||||
self.fragment = fragment
|
||||
def __repr__(self):
return "CrossAttentionControlAppend:" + str(self.fragment)
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlAppend \
|
||||
and other.fragment == self.fragment
|
||||
|
||||
|
||||
|
||||
class Conjunction():
|
||||
"""
|
||||
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
|
||||
by weighted sum in latent space.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list = None):
|
||||
# force everything to be a Prompt
|
||||
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
|
||||
self.prompts = [x if (type(x) is Prompt
|
||||
or type(x) is Blend
|
||||
or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||
if len(self.weights) != len(self.prompts):
|
||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||
self.type = 'AND'
|
||||
|
||||
def __repr__(self):
|
||||
return f"Conjunction:{self.prompts} | weights {self.weights}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Conjunction \
|
||||
and other.prompts == self.prompts \
|
||||
and other.weights == self.weights
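# Hypothetical usage sketch (prompt syntax assumed from build_parser_syntax() below):
#   ("a forest", "a city").and(1, 2)
# should parse to a Conjunction of two prompts with weights [1.0, 2.0]; each prompt is
# diffused separately and the results are combined by weighted sum in latent space.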
|
||||
|
||||
|
||||
class Blend():
|
||||
"""
|
||||
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
|
||||
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
|
||||
Prompts.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||
if len(prompts) != len(weights):
|
||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||
for p in prompts:
|
||||
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
|
||||
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
|
||||
for f in p.children:
|
||||
if isinstance(f, CrossAttentionControlSubstitute):
|
||||
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
|
||||
|
||||
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x)
for x in prompts]
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
|
||||
def __repr__(self):
|
||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||
def __eq__(self, other):
|
||||
return other.__repr__() == self.__repr__()
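# Hypothetical usage sketch (syntax assumed from the .blend operator defined below):
#   ("blue sphere", "red cube").blend(0.25, 0.75)
# should parse to a Blend of two prompts; their conditioning vectors are lerped with the
# given weights (normalized unless a trailing no_normalize flag is supplied).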
|
||||
|
||||
|
||||
class PromptParser():
|
||||
|
||||
class ParsingException(Exception):
|
||||
pass
|
||||
|
||||
class UnrecognizedOperatorException(ParsingException):
|
||||
def __init__(self, operator:str):
|
||||
super().__init__("Unrecognized operator: " + operator)
|
||||
|
||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||
|
||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||
|
||||
|
||||
def parse_conjunction(self, prompt: str) -> Conjunction:
|
||||
'''
|
||||
:param prompt: The prompt string to parse
|
||||
:return: a Conjunction representing the parsed results.
|
||||
'''
|
||||
#print(f"!!parsing '{prompt}'")
|
||||
|
||||
if len(prompt.strip()) == 0:
|
||||
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
|
||||
|
||||
root = self.conjunction.parse_string(prompt)
|
||||
#print(f"'{prompt}' parsed to root", root)
|
||||
#fused = fuse_fragments(parts)
|
||||
#print("fused to", fused)
|
||||
|
||||
return self.flatten(root[0])
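# Minimal usage sketch (output shape inferred from the classes above; adjacent fragments with
# equal weight are fused into one):
#   parser = PromptParser()
#   parser.parse_conjunction('a cat sitting on a car')
#   -> Conjunction:[FlattenedPrompt:[Fragment:'a cat sitting on a car'@1.0]] | weights [1.0]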
|
||||
|
||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
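# Sketch of the legacy behaviour (weights come from split_weighted_subprompts below):
#   parse_legacy_blend('mountain:1 lake:3')
# should return a Blend of the two flattened prompts with normalized weights [0.25, 0.75];
# a prompt without ':' weights yields a single subprompt and returns None, so it can fall
# through to the new parser.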
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||
"""
|
||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||
that can be readily tokenized without the need to walk a complex tree structure.
|
||||
|
||||
:param root: The Conjunction to flatten.
|
||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||
"""
|
||||
|
||||
def fuse_fragments(items):
|
||||
# print("fusing fragments in ", items)
|
||||
result = []
|
||||
for x in items:
|
||||
if type(x) is CrossAttentionControlSubstitute:
|
||||
original_fused = fuse_fragments(x.original)
|
||||
edited_fused = fuse_fragments(x.edited)
|
||||
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
|
||||
else:
|
||||
last_weight = result[-1].weight \
|
||||
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
||||
else None
|
||||
this_text = x.text
|
||||
this_weight = x.weight
|
||||
if last_weight is not None and last_weight == this_weight:
|
||||
last_text = result[-1].text
|
||||
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
|
||||
else:
|
||||
result.append(x)
|
||||
return result
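# e.g. (sketch): fuse_fragments([Fragment('a', 1.0), Fragment('cat', 1.0), Fragment('dog', 1.1)])
# -> [Fragment('a cat', 1.0), Fragment('dog', 1.1)] - adjacent fragments are only merged when
# their weights match, and CrossAttentionControlSubstitute boundaries are never crossed.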
|
||||
|
||||
def flatten_internal(node, weight_scale, results, prefix):
|
||||
verbose and print(prefix + "flattening", node, "...")
|
||||
if type(node) is pp.ParseResults or type(node) is list:
|
||||
for x in node:
|
||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||
#print(prefix, " ParseResults expanded, results is now", results)
|
||||
elif type(node) is Attention:
|
||||
# if node.weight < 1:
|
||||
# todo: inject a blend when flattening attention with weight <1"
|
||||
for index,c in enumerate(node.children):
|
||||
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
|
||||
elif type(node) is Fragment:
|
||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||
elif type(node) is CrossAttentionControlSubstitute:
|
||||
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
|
||||
elif type(node) is Blend:
|
||||
flattened_subprompts = []
|
||||
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
||||
for prompt in node.prompts:
|
||||
# prompt is a list
|
||||
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
|
||||
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
|
||||
elif type(node) is Prompt:
|
||||
#print(prefix + "about to flatten Prompt with children", node.children)
|
||||
flattened_prompt = []
|
||||
for child in node.children:
|
||||
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
|
||||
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
return results
|
||||
|
||||
verbose and print("flattening", root)
|
||||
|
||||
flattened_parts = []
|
||||
for part in root.prompts:
|
||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||
|
||||
verbose and print("flattened to", flattened_parts)
|
||||
|
||||
weights = root.weights
|
||||
return Conjunction(flattened_parts, weights)
|
||||
|
||||
|
||||
|
||||
|
||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||
def make_operator_object(x):
|
||||
#print('making operator for', x)
|
||||
target = x[0]
|
||||
operator = x[1]
|
||||
arguments = x[2]
|
||||
if operator == '.attend':
|
||||
weight_raw = arguments[0]
|
||||
weight = 1.0
|
||||
if type(weight_raw) is float or type(weight_raw) is int:
|
||||
weight = weight_raw
|
||||
elif type(weight_raw) is str:
|
||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||
weight = pow(base, len(weight_raw))
|
||||
return Attention(weight=weight, children=[x for x in x[0]])
|
||||
elif operator == '.swap':
|
||||
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
|
||||
elif operator == '.blend':
|
||||
prompts = [Prompt(p) for p in x[0]]
|
||||
weights_raw = x[2]
|
||||
normalize_weights = True
|
||||
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
|
||||
normalize_weights = False
|
||||
weights_raw = weights_raw[:-1]
|
||||
weights = [float(w[0]) for w in weights_raw]
|
||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
|
||||
elif operator == '.and' or operator == '.add':
|
||||
prompts = [Prompt(p) for p in x[0]]
|
||||
weights = [float(w[0]) for w in x[2]]
|
||||
return Conjunction(prompts=prompts, weights=weights)
|
||||
|
||||
raise PromptParser.UnrecognizedOperatorException(operator)
|
||||
|
||||
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
|
||||
#print(f"parsing fragment string for {x}")
|
||||
fragment_string = x[0]
|
||||
if len(fragment_string.strip()) == 0:
|
||||
return Fragment('')
|
||||
|
||||
if in_quotes:
|
||||
# escape unescaped quotes
|
||||
fragment_string = fragment_string.replace('"', '\\"')
|
||||
|
||||
try:
|
||||
result = (expression + pp.StringEnd()).parse_string(fragment_string)
|
||||
#print("parsed to", result)
|
||||
return result
|
||||
except pp.ParseException as e:
|
||||
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||
raise
|
||||
|
||||
# meaningful symbols
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quote = pp.Literal('"').suppress()
|
||||
comma = pp.Literal(",").suppress()
|
||||
dot = pp.Literal(".").suppress()
|
||||
equals = pp.Literal("=").suppress()
|
||||
|
||||
escaped_lparen = pp.Literal('\\(')
|
||||
escaped_rparen = pp.Literal('\\)')
|
||||
escaped_quote = pp.Literal('\\"')
|
||||
escaped_comma = pp.Literal('\\,')
|
||||
escaped_dot = pp.Literal('\\.')
|
||||
escaped_plus = pp.Literal('\\+')
|
||||
escaped_minus = pp.Literal('\\-')
|
||||
escaped_equals = pp.Literal('\\=')
|
||||
|
||||
syntactic_symbols = {
|
||||
'(': escaped_lparen,
|
||||
')': escaped_rparen,
|
||||
'"': escaped_quote,
|
||||
',': escaped_comma,
|
||||
'.': escaped_dot,
|
||||
'+': escaped_plus,
|
||||
'-': escaped_minus,
|
||||
'=': escaped_equals,
|
||||
}
|
||||
syntactic_chars = "".join(syntactic_symbols.keys())
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
|
||||
# for options
|
||||
keyword = pp.Word(pp.alphanums + '_')
|
||||
|
||||
# a word that absolutely does not contain any meaningful syntax
|
||||
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
|
||||
pp.Or(syntactic_symbols.values()),
|
||||
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
|
||||
# build character-by-character
|
||||
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
|
||||
])))
|
||||
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
|
||||
non_syntax_word.set_name('non_syntax_word')
|
||||
non_syntax_word.set_debug(False)
|
||||
|
||||
# a word that can contain any character at all - greedily consumes syntax, so use with care
|
||||
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
|
||||
free_word.set_name('free_word')
|
||||
free_word.set_debug(False)
|
||||
|
||||
|
||||
# ok here we go. forward declare some things..
|
||||
attention = pp.Forward()
|
||||
cross_attention_substitute = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
|
||||
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
|
||||
fragment_part_expressions = [
|
||||
attention,
|
||||
cross_attention_substitute,
|
||||
parenthesized_fragment,
|
||||
quoted_fragment,
|
||||
non_syntax_word
|
||||
]
|
||||
# a fragment that is permitted to contain commas
|
||||
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||
fragment_part_expressions + [
|
||||
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
|
||||
]
|
||||
))
|
||||
# a fragment that is not permitted to contain commas
|
||||
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||
fragment_part_expressions
|
||||
))
|
||||
|
||||
# a fragment in double quotes (may be nested)
|
||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
|
||||
|
||||
# a fragment inside parentheses (may be nested)
|
||||
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment')
|
||||
parenthesized_fragment.set_debug(False)
|
||||
|
||||
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
|
||||
option = pp.Group(pp.MatchFirst([
|
||||
keyword + equals + (number | keyword), # option=value
|
||||
number.copy().set_parse_action(pp.token_map(str)), # weight
|
||||
keyword # flag
|
||||
]))
|
||||
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
|
||||
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
|
||||
options.set_name('options')
|
||||
options.set_debug(False)
|
||||
|
||||
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
|
||||
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
|
||||
|
||||
# a fragment whose weight has been increased or decreased by a given amount
|
||||
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
|
||||
attention_explicit = (
|
||||
pp.Group(potential_operator_target)
|
||||
+ pp.Literal('.attend')
|
||||
+ lparen
|
||||
+ pp.Group(attention_weight_operator)
|
||||
+ rparen
|
||||
)
|
||||
attention_explicit.set_parse_action(make_operator_object)
|
||||
attention_implicit = (
|
||||
pp.Group(potential_operator_target)
|
||||
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
|
||||
+ pp.Group(attention_weight_operator)
|
||||
)
|
||||
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
|
||||
attention << (attention_explicit | attention_implicit)
|
||||
attention.set_name('attention')
|
||||
attention.set_debug(False)
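# Examples of the attention syntax handled here (sketch; weights assume the default
# attention_plus_base=1.1 and attention_minus_base=0.9 from PromptParser.__init__):
#   'cat+'         -> Fragment('cat') weighted 1.1
#   'cat++'        -> Fragment('cat') weighted 1.1**2 ~= 1.21
#   '(big cat)1.4' -> Fragment('big cat') weighted 1.4
#   'cat-'         -> Fragment('cat') weighted 0.9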
|
||||
|
||||
# cross-attention control by swapping one fragment for another
|
||||
cross_attention_substitute << (
|
||||
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
|
||||
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
|
||||
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||
cross_attention_substitute.set_debug(False)
|
||||
cross_attention_substitute.set_parse_action(make_operator_object)
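# Examples of the .swap syntax handled here (sketch, option names taken from
# CrossAttentionControlSubstitute above):
#   'a cat.swap(dog) sitting on a car'
#   'a "cat and mouse".swap("dog and bird", s_start=0.1, shape_freedom=0.5) in a field'
# Any arguments after the replacement fragment are collected into the options dict.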
|
||||
|
||||
|
||||
# an entire self-contained prompt, which can be used in a Blend or Conjunction
|
||||
prompt = pp.ZeroOrMore(pp.MatchFirst([
|
||||
cross_attention_substitute,
|
||||
attention,
|
||||
quoted_fragment,
|
||||
parenthesized_fragment,
|
||||
free_word,
|
||||
pp.White().suppress()
|
||||
]))
|
||||
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
|
||||
|
||||
|
||||
# a blend/lerp between the feature vectors for two or more prompts
|
||||
blend = (
|
||||
lparen
|
||||
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
|
||||
+ rparen
|
||||
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(options).set_name('bl-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
blend.set_name('blend')
|
||||
blend.set_debug(False)
|
||||
blend.set_parse_action(make_operator_object)
|
||||
|
||||
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
|
||||
explicit_conjunction = (
|
||||
lparen
|
||||
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
|
||||
+ rparen
|
||||
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(options).set_name('cj-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
explicit_conjunction.set_name('explicit_conjunction')
|
||||
explicit_conjunction.set_debug(False)
|
||||
explicit_conjunction.set_parse_action(make_operator_object)
|
||||
|
||||
# by default a prompt consists of a Conjunction with a single term
|
||||
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
|
||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||
|
||||
conjunction = (explicit_conjunction | implicit_conjunction)
|
||||
|
||||
return conjunction, prompt
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
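# e.g. (sketch): split_weighted_subprompts('forest trees:1.5 sunshine:0.5')
# -> [('forest trees', 0.75), ('sunshine', 0.25)]   (weights normalized to sum to 1)
# With skip_normalize=True the raw weights [1.5, 0.5] are returned instead.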
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
ldm/invoke/readline.py (Normal file, 380 lines)
@ -0,0 +1,380 @@
|
||||
"""
|
||||
Readline helper functions for invoke.py.
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object itself. This is useful when you want to autocomplete
|
||||
seeds:
|
||||
|
||||
from ldm.invoke.readline import completer
|
||||
completer.add_seed(18247566)
|
||||
completer.add_seed(9281839)
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import atexit
|
||||
from ldm.invoke.args import Args
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
try:
|
||||
import readline
|
||||
readline_available = True
|
||||
except (ImportError,ModuleNotFoundError):
|
||||
readline_available = False
|
||||
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
|
||||
TEXT_EXTENSIONS = ('.txt','.TXT')
|
||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||
COMMANDS = (
|
||||
'--steps','-s',
|
||||
'--seed','-S',
|
||||
'--iterations','-n',
|
||||
'--width','-W','--height','-H',
|
||||
'--cfg_scale','-C',
|
||||
'--threshold',
|
||||
'--perlin',
|
||||
'--grid','-g',
|
||||
'--individual','-i',
|
||||
'--save_intermediates',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--init_color',
|
||||
'--strength','-f',
|
||||
'--variants','-v',
|
||||
'--outdir','-o',
|
||||
'--sampler','-A','-m',
|
||||
'--embedding_path',
|
||||
'--device',
|
||||
'--grid','-g',
|
||||
'--facetool','-ft',
|
||||
'--facetool_strength','-G',
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--skip_normalize','-x',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||
'!mask',
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
'!edit_model',
|
||||
'!del_model',
|
||||
)
|
||||
WEIGHT_COMMANDS = (
|
||||
'!import_model',
|
||||
)
|
||||
IMG_PATH_COMMANDS = (
|
||||
'--outdir[=\s]',
|
||||
)
|
||||
TEXT_PATH_COMMANDS=(
|
||||
'!replay',
|
||||
)
|
||||
IMG_FILE_COMMANDS=(
|
||||
'!fix',
|
||||
'!fetch',
|
||||
'!mask',
|
||||
'--init_img[=\s]','-I',
|
||||
'--init_mask[=\s]','-M',
|
||||
'--init_color[=\s]',
|
||||
'--embedding_path[=\s]',
|
||||
)
|
||||
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
|
||||
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options, models=[]):
|
||||
self.options = sorted(options)
|
||||
self.models = sorted(models)
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
self.linebuffer = None
|
||||
self.auto_history_active = True
|
||||
self.extensions = None
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
'''
|
||||
Completes invoke command line.
|
||||
BUG: it doesn't correctly complete files that have spaces in the name.
|
||||
'''
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if state == 0:
|
||||
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp,buffer):
|
||||
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
|
||||
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
|
||||
|
||||
# looking for a seed
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
|
||||
# looking for a model
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
|
||||
elif re.search(weight_regexp,buffer):
|
||||
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
|
||||
|
||||
elif re.search(text_regexp,buffer):
|
||||
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
|
||||
|
||||
# This is the first time for this text, so build a match list.
|
||||
elif text:
|
||||
self.matches = [
|
||||
s for s in self.options if s and s.startswith(text)
|
||||
]
|
||||
else:
|
||||
self.matches = self.options[:]
|
||||
|
||||
# Return the state'th item from the match list,
|
||||
# if we have that many.
|
||||
try:
|
||||
response = self.matches[state]
|
||||
except IndexError:
|
||||
response = None
|
||||
return response
|
||||
|
||||
def complete_extensions(self, extensions:list):
|
||||
'''
|
||||
If called with a list of extensions, will force completer
|
||||
to do file path completions.
|
||||
'''
|
||||
self.extensions=extensions
|
||||
|
||||
def add_history(self,line):
|
||||
'''
|
||||
Pass thru to readline
|
||||
'''
|
||||
if not self.auto_history_active:
|
||||
readline.add_history(line)
|
||||
|
||||
def clear_history(self):
|
||||
'''
|
||||
Pass clear_history() thru to readline
|
||||
'''
|
||||
readline.clear_history()
|
||||
|
||||
def search_history(self,match:str):
|
||||
'''
|
||||
Like show_history() but only shows items that
|
||||
contain the match string.
|
||||
'''
|
||||
self.show_history(match)
|
||||
|
||||
def remove_history_item(self,pos):
|
||||
readline.remove_history_item(pos)
|
||||
|
||||
def add_seed(self, seed):
|
||||
'''
|
||||
Add a seed to the autocomplete list for display when -S is autocompleted.
|
||||
'''
|
||||
if seed is not None:
|
||||
self.seeds.add(str(seed))
|
||||
|
||||
def set_default_dir(self, path):
|
||||
self.default_dir=path
|
||||
|
||||
def get_line(self,index):
|
||||
try:
|
||||
line = self.get_history_item(index)
|
||||
except IndexError:
|
||||
return None
|
||||
return line
|
||||
|
||||
def get_current_history_length(self):
|
||||
return readline.get_current_history_length()
|
||||
|
||||
def get_history_item(self,index):
|
||||
return readline.get_history_item(index)
|
||||
|
||||
def show_history(self,match=None):
|
||||
'''
|
||||
Print the session history using the pydoc pager
|
||||
'''
|
||||
import pydoc
|
||||
lines = list()
|
||||
h_len = self.get_current_history_length()
|
||||
if h_len < 1:
|
||||
print('<empty history>')
|
||||
return
|
||||
|
||||
for i in range(0,h_len):
|
||||
line = self.get_history_item(i+1)
|
||||
if match and match not in line:
|
||||
continue
|
||||
lines.append(f'[{i+1}] {line}')
|
||||
pydoc.pager('\n'.join(lines))
|
||||
|
||||
def set_line(self,line)->None:
|
||||
'''
|
||||
Set the default string displayed in the next line of input.
|
||||
'''
|
||||
self.linebuffer = line
|
||||
readline.redisplay()
|
||||
|
||||
def add_model(self,model_name:str)->None:
|
||||
'''
|
||||
add a model name to the completion list
|
||||
'''
|
||||
self.models.append(model_name)
|
||||
|
||||
def del_model(self,model_name:str)->None:
|
||||
'''
|
||||
removes a model name from the completion list
|
||||
'''
|
||||
self.models.remove(model_name)
|
||||
|
||||
def _seed_completions(self, text, state):
|
||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
partial = text
|
||||
|
||||
matches = list()
|
||||
for s in self.seeds:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.sort()
|
||||
return matches
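# Sketch of the intended behaviour: after completer.add_seed(18247566), a partial seed such as
# '18' typed after '-S ' should complete to '18247566'; the switch is re-attached only when it
# is part of the word being completed.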
|
||||
|
||||
def _model_completions(self, text, state):
|
||||
m = re.search('(!switch\s+)(\w*)',text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def _path_completions(self, text, state, extensions, shortcut_ok=True):
|
||||
# separate the switch from the partial path
|
||||
match = re.search('^(-\w|--\w+=?)(.*)',text)
|
||||
if match is None:
|
||||
switch = None
|
||||
partial_path = text
|
||||
else:
|
||||
switch,partial_path = match.groups()
|
||||
partial_path = partial_path.lstrip()
|
||||
|
||||
matches = list()
|
||||
path = os.path.expanduser(partial_path)
|
||||
|
||||
if os.path.isdir(path):
|
||||
dir = path
|
||||
elif os.path.dirname(path) != '':
|
||||
dir = os.path.dirname(path)
|
||||
else:
|
||||
dir = ''
|
||||
path= os.path.join(dir,path)
|
||||
|
||||
dir_list = os.listdir(dir or '.')
|
||||
if shortcut_ok and os.path.exists(self.default_dir) and dir=='':
|
||||
dir_list += os.listdir(self.default_dir)
|
||||
|
||||
for node in dir_list:
|
||||
if node.startswith('.') and len(node) > 1:
|
||||
continue
|
||||
full_path = os.path.join(dir, node)
|
||||
|
||||
if not (node.endswith(extensions) or os.path.isdir(full_path)):
|
||||
continue
|
||||
|
||||
if not full_path.startswith(path):
|
||||
continue
|
||||
|
||||
if switch is None:
|
||||
match_path = os.path.join(dir,node)
|
||||
matches.append(match_path+'/' if os.path.isdir(full_path) else match_path)
|
||||
elif os.path.isdir(full_path):
|
||||
matches.append(
|
||||
switch+os.path.join(os.path.dirname(full_path), node) + '/'
|
||||
)
|
||||
elif node.endswith(extensions):
|
||||
matches.append(
|
||||
switch+os.path.join(os.path.dirname(full_path), node)
|
||||
)
|
||||
return matches
|
||||
|
||||
class DummyCompleter(Completer):
|
||||
def __init__(self,options):
|
||||
super().__init__(options)
|
||||
self.history = list()
|
||||
|
||||
def add_history(self,line):
|
||||
self.history.append(line)
|
||||
|
||||
def clear_history(self):
|
||||
self.history = list()
|
||||
|
||||
def get_current_history_length(self):
|
||||
return len(self.history)
|
||||
|
||||
def get_history_item(self,index):
|
||||
return self.history[index-1]
|
||||
|
||||
def remove_history_item(self,index):
|
||||
return self.history.pop(index-1)
|
||||
|
||||
def set_line(self,line):
|
||||
print(f'# {line}')
|
||||
|
||||
def get_completer(opt:Args, models=[])->Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS,models)
|
||||
|
||||
readline.set_completer(
|
||||
completer.complete
|
||||
)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
try:
|
||||
readline.set_auto_history(False)
|
||||
completer.auto_history_active = False
|
||||
except:
|
||||
completer.auto_history_active = True
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(' ')
|
||||
readline.parse_and_bind('tab: complete')
|
||||
readline.parse_and_bind('set print-completions-horizontally off')
|
||||
readline.parse_and_bind('set page-completions on')
|
||||
readline.parse_and_bind('set skip-completed-text on')
|
||||
readline.parse_and_bind('set show-all-if-ambiguous on')
|
||||
|
||||
histfile = os.path.join(os.path.expanduser(opt.outdir), '.invoke_history')
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
else:
|
||||
completer = DummyCompleter(COMMANDS)
|
||||
return completer
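# Hypothetical wiring sketch (Args construction and the model name are placeholders;
# get_completer() itself already binds readline when it is available):
#   opt = Args()
#   completer = get_completer(opt, models=['stable-diffusion-1.4'])
#   completer.add_seed(42)   # seeds passed with -S/--seed become tab-completable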
|
ldm/invoke/restoration/__init__.py (Normal file, 4 lines)
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the ldm.invoke.restoration package
|
||||
'''
|
||||
from .base import Restoration
|
ldm/invoke/restoration/base.py (Normal file, 38 lines)
@ -0,0 +1,38 @@
|
||||
class Restoration():
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def load_face_restore_models(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.4.pth'):
|
||||
# Load GFPGAN
|
||||
gfpgan = self.load_gfpgan(gfpgan_dir, gfpgan_model_path)
|
||||
if gfpgan.gfpgan_model_exists:
|
||||
print('>> GFPGAN Initialized')
|
||||
else:
|
||||
print('>> GFPGAN Disabled')
|
||||
gfpgan = None
|
||||
|
||||
# Load CodeFormer
|
||||
codeformer = self.load_codeformer()
|
||||
if codeformer.codeformer_model_exists:
|
||||
print('>> CodeFormer Initialized')
|
||||
else:
|
||||
print('>> CodeFormer Disabled')
|
||||
codeformer = None
|
||||
|
||||
return gfpgan, codeformer
|
||||
|
||||
# Face Restore Models
|
||||
def load_gfpgan(self, gfpgan_dir, gfpgan_model_path):
|
||||
from ldm.invoke.restoration.gfpgan import GFPGAN
|
||||
return GFPGAN(gfpgan_dir, gfpgan_model_path)
|
||||
|
||||
def load_codeformer(self):
|
||||
from ldm.invoke.restoration.codeformer import CodeFormerRestoration
|
||||
return CodeFormerRestoration()
|
||||
|
||||
# Upscale Models
|
||||
def load_esrgan(self, esrgan_bg_tile=400):
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN(esrgan_bg_tile)
|
||||
print('>> ESRGAN Initialized')
|
||||
return esrgan
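# Hypothetical usage sketch using only the methods defined above:
#   restoration = Restoration()
#   gfpgan, codeformer = restoration.load_face_restore_models()
#   esrgan = restoration.load_esrgan(esrgan_bg_tile=400)
# The face-restoration loaders print whether each model was found and return None for a
# model that is not installed.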
|
ldm/invoke/restoration/codeformer.py (Normal file, 87 lines)
@ -0,0 +1,87 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
class CodeFormerRestoration():
|
||||
def __init__(self,
|
||||
codeformer_dir='ldm/invoke/restoration/codeformer',
|
||||
codeformer_model_path='weights/codeformer.pth') -> None:
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
|
||||
sys.path.append(os.path.abspath(codeformer_dir))
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
if seed is not None:
|
||||
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from ldm.invoke.restoration.codeformer_arch import CodeFormer
|
||||
from torchvision.transforms.functional import normalize
|
||||
from PIL import Image
|
||||
|
||||
cf_class = CodeFormer
|
||||
|
||||
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint = torch.load(checkpoint_path)['params_ema']
|
||||
cf.load_state_dict(checkpoint)
|
||||
cf.eval()
|
||||
|
||||
image = image.convert('RGB')
|
||||
# Codeformer expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(bgr_image_array)
|
||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
|
||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as error:
|
||||
print(f'\tFailed inference for CodeFormer: {error}.')
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
|
||||
face_helper.get_inverse_affine(None)
|
||||
|
||||
restored_img = face_helper.paste_faces_to_input_image()
|
||||
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(restored_img[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
if restored_img.size != image.size:
|
||||
image = image.resize(res.size)
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
cf = None
|
||||
|
||||
return res
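# Hypothetical usage sketch (the PIL image and torch device are assumed to be set up by the caller):
#   restorer = CodeFormerRestoration()
#   if restorer.codeformer_model_exists:
#       restored = restorer.process(image, strength=0.8, device='cuda', seed=seed, fidelity=0.75)
# strength < 1.0 blends the restored faces back into the original image.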
|
ldm/invoke/restoration/codeformer/weights/README (Normal file, 3 lines)
@ -0,0 +1,3 @@
|
||||
To use codeformer face reconstruction, you will need to copy
|
||||
https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth
|
||||
into this directory.
|
ldm/invoke/restoration/codeformer_arch.py (Normal file, 276 lines)
@ -0,0 +1,276 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
|
||||
from ldm.invoke.restoration.vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
|
||||
def adaptive_instance_normalization(content_feat, style_feat):
|
||||
"""Adaptive instance normalization.
|
||||
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
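# In other words (sketch of the formula implemented above):
#   AdaIN(content, style) = sigma(style) * (content - mu(content)) / sigma(content) + mu(style)
# where mu and sigma are the per-channel mean and std computed by calc_mean_std().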
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is None:
|
||||
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransformerSALayer(nn.Module):
|
||||
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model - MLP
|
||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
||||
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward(self, tgt,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
# ffn
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
class Fuse_sft_block(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
||||
|
||||
self.scale = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
self.shift = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, enc_feat, dec_feat, w=1):
|
||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
||||
scale = self.scale(enc_feat)
|
||||
shift = self.shift(enc_feat)
|
||||
residual = w * (dec_feat * scale + shift)
|
||||
out = dec_feat + residual
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=['32', '64', '128', '256'],
|
||||
fix_modules=['quantize','generator']):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.connect_list = connect_list
|
||||
self.n_layers = n_layers
|
||||
self.dim_embd = dim_embd
|
||||
self.dim_mlp = dim_embd*2
|
||||
|
||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
'64': 256,
|
||||
'128': 128,
|
||||
'256': 128,
|
||||
'512': 64,
|
||||
}
|
||||
|
||||
# after second residual block for > 16, before attn layer for ==16
|
||||
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
||||
# after first residual block for > 16, before attn layer for ==16
|
||||
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
||||
|
||||
# fuse_convs_dict
|
||||
self.fuse_convs_dict = nn.ModuleDict()
|
||||
for f_size in self.connect_list:
|
||||
in_ch = self.channels[f_size]
|
||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
||||
# ################### Encoder #####################
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
lq_feat = x
|
||||
# ################# Transformer ###################
|
||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
||||
# BCHW -> BC(HW) -> (HW)BC
|
||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
||||
query_emb = feat_emb
|
||||
# Transformer encoder
|
||||
for layer in self.ft_layers:
|
||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
||||
|
||||
# output logits
|
||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
||||
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
||||
|
||||
if code_only: # for training stage II
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return logits, lq_feat
|
||||
|
||||
# ################# Quantization ###################
|
||||
# if self.training:
|
||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
||||
# # b(hw)c -> bc(hw) -> bchw
|
||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
||||
# ------------
|
||||
soft_one_hot = F.softmax(logits, dim=2)
|
||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
||||
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
||||
# preserve gradients
|
||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
||||
|
||||
if detach_16:
|
||||
quant_feat = quant_feat.detach() # for training stage III
|
||||
if adain:
|
||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
||||
|
||||
# ################## Generator ####################
|
||||
x = quant_feat
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
ldm/invoke/restoration/gfpgan.py (Normal file, 81 lines)
@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import warnings
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class GFPGAN():
|
||||
def __init__(
|
||||
self,
|
||||
gfpgan_dir='src/gfpgan',
|
||||
gfpgan_model_path='experiments/pretrained_models/GFPGANv1.4.pth') -> None:
|
||||
|
||||
self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
|
||||
return None
|
||||
sys.path.append(os.path.abspath(gfpgan_dir))
|
||||
|
||||
def model_exists(self):
|
||||
return os.path.isfile(self.model_path)
|
||||
|
||||
def process(self, image, strength: float, seed: str = None):
|
||||
if seed is not None:
|
||||
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
try:
|
||||
from gfpgan import GFPGANer
|
||||
self.gfpgan = GFPGANer(
|
||||
model_path=self.model_path,
|
||||
upscale=1,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=None,
|
||||
)
|
||||
except Exception:
|
||||
import traceback
|
||||
print('>> Error loading GFPGAN:', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
f'>> WARNING: GFPGAN not initialized.'
|
||||
)
|
||||
print(
|
||||
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
|
||||
)
|
||||
|
||||
image = image.convert('RGB')
|
||||
|
||||
# GFPGAN expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
_, _, restored_img = self.gfpgan.enhance(
|
||||
bgr_image_array,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
)
|
||||
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(restored_img[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
if restored_img.size != image.size:
|
||||
image = image.resize(res.size)
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
self.gfpgan = None
|
||||
|
||||
return res
|
ldm/invoke/restoration/outcrop.py (Normal file, 107 lines)
@ -0,0 +1,107 @@
|
||||
import warnings
|
||||
import math
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class Outcrop(object):
|
||||
def __init__(
|
||||
self,
|
||||
image,
|
||||
generate, # current generate object
|
||||
):
|
||||
self.image = image
|
||||
self.generate = generate
|
||||
|
||||
def process (
|
||||
self,
|
||||
extents:dict,
|
||||
opt, # current options
|
||||
orig_opt, # ones originally used to generate the image
|
||||
image_callback = None,
|
||||
prefix = None
|
||||
):
|
||||
# grow and mask the image
|
||||
extended_image = self._extend_all(extents)
|
||||
|
||||
# switch samplers temporarily
|
||||
curr_sampler = self.generate.sampler
|
||||
self.generate.sampler_name = opt.sampler_name
|
||||
self.generate._set_sampler()
|
||||
|
||||
def wrapped_callback(img,seed,**kwargs):
|
||||
image_callback(img,orig_opt.seed,use_prefix=prefix,**kwargs)
|
||||
|
||||
result= self.generate.prompt2image(
|
||||
orig_opt.prompt,
|
||||
seed = orig_opt.seed, # uncomment to make it deterministic
|
||||
sampler = self.generate.sampler,
|
||||
steps = opt.steps,
|
||||
cfg_scale = opt.cfg_scale,
|
||||
ddim_eta = self.generate.ddim_eta,
|
||||
width = extended_image.width,
|
||||
height = extended_image.height,
|
||||
init_img = extended_image,
|
||||
strength = 0.90,
|
||||
image_callback = wrapped_callback if image_callback else None,
|
||||
seam_size = opt.seam_size or 96,
|
||||
seam_blur = opt.seam_blur or 16,
|
||||
seam_strength = opt.seam_strength or 0.7,
|
||||
seam_steps = 20,
|
||||
tile_size = 32,
|
||||
color_match = True,
|
||||
force_outpaint = True, # this just stops the warning about erased regions
|
||||
)
|
||||
|
||||
# swap sampler back
|
||||
self.generate.sampler = curr_sampler
|
||||
return result
|
||||
|
||||
def _extend_all(
|
||||
self,
|
||||
extents:dict,
|
||||
) -> Image:
|
||||
'''
|
||||
Extend the image in direction ('top','bottom','left','right') by
|
||||
the indicated value. The image canvas is extended, and the empty
|
||||
rectangular section will be filled with a blurred copy of the
|
||||
adjacent image.
|
||||
'''
|
||||
image = self.image
|
||||
for direction in extents:
|
||||
assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"'
|
||||
pixels = extents[direction]
|
||||
# round pixels up to the nearest 64
|
||||
pixels = math.ceil(pixels/64) * 64
|
||||
print(f'>> extending image {direction}ward by {pixels} pixels')
|
||||
image = self._rotate(image,direction)
|
||||
image = self._extend(image,pixels)
|
||||
image = self._rotate(image,direction,reverse=True)
|
||||
return image
|
||||
|
||||
def _rotate(self,image:Image,direction:str,reverse=False) -> Image:
|
||||
'''
|
||||
Rotates the image so that the area to extend is always at the top.
|
||||
Simplifies logic later. The reverse argument, if true, will undo the
|
||||
previous transpose.
|
||||
'''
|
||||
transposes = {
|
||||
'right': ['ROTATE_90','ROTATE_270'],
|
||||
'bottom': ['ROTATE_180','ROTATE_180'],
|
||||
'left': ['ROTATE_270','ROTATE_90']
|
||||
}
|
||||
if direction not in transposes:
|
||||
return image
|
||||
transpose = transposes[direction][1 if reverse else 0]
|
||||
return image.transpose(Image.Transpose.__dict__[transpose])
|
||||
|
||||
def _extend(self,image:Image,pixels:int)-> Image:
|
||||
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
|
||||
|
||||
extended_img.paste((0,0,0),[0,0,image.width,image.height+pixels])
|
||||
extended_img.paste(image,box=(0,pixels))
|
||||
|
||||
# now make the top part transparent to use as a mask
|
||||
alpha = extended_img.getchannel('A')
|
||||
alpha.paste(0,(0,0,extended_img.width,pixels))
|
||||
extended_img.putalpha(alpha)
|
||||
|
||||
return extended_img
|
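A rough usage sketch for the class above. The generate object and the two option namespaces are assumed to already exist in the calling application (they are not constructed here); extents are pixel amounts and, as the code notes, get rounded up to the nearest multiple of 64:

from PIL import Image
from ldm.invoke.restoration.outcrop import Outcrop

# `gen`, `opt` and `orig_opt` are assumed: the running Generate instance and
# the current/original Args namespaces from the calling application.
init_image = Image.open('outputs/000001.1234567890.png')
outcropper = Outcrop(init_image, gen)
results = outcropper.process(
    extents={'top': 128, 'right': 64},  # rounded up to multiples of 64
    opt=opt,
    orig_opt=orig_opt,
)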
ldm/invoke/restoration/outpaint.py (new file, 92 lines)
@ -0,0 +1,92 @@
|
||||
import warnings
|
||||
import math
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class Outpaint(object):
|
||||
def __init__(self, image, generate):
|
||||
self.image = image
|
||||
self.generate = generate
|
||||
|
||||
def process(self, opt, old_opt, image_callback = None, prefix = None):
|
||||
image = self._create_outpaint_image(self.image, opt.out_direction)
|
||||
|
||||
seed = old_opt.seed
|
||||
prompt = old_opt.prompt
|
||||
|
||||
def wrapped_callback(img,seed,**kwargs):
|
||||
image_callback(img,seed,use_prefix=prefix,**kwargs)
|
||||
|
||||
|
||||
return self.generate.prompt2image(
|
||||
prompt,
|
||||
seed = seed,
|
||||
sampler = self.generate.sampler,
|
||||
steps = opt.steps,
|
||||
cfg_scale = opt.cfg_scale,
|
||||
ddim_eta = self.generate.ddim_eta,
|
||||
width = opt.width,
|
||||
height = opt.height,
|
||||
init_img = image,
|
||||
strength = 0.83,
|
||||
image_callback = wrapped_callback,
|
||||
prefix = prefix,
|
||||
)
|
||||
|
||||
def _create_outpaint_image(self, image, direction_args):
|
||||
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
|
||||
|
||||
if len(direction_args) == 1:
|
||||
direction = direction_args[0]
|
||||
pixels = None
|
||||
elif len(direction_args) == 2:
|
||||
direction = direction_args[0]
|
||||
pixels = int(direction_args[1])
|
||||
|
||||
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
|
||||
|
||||
image = image.convert("RGBA")
|
||||
# we always extend top, but rotate to extend along the requested side
|
||||
if direction == 'left':
|
||||
image = image.transpose(Image.Transpose.ROTATE_270)
|
||||
elif direction == 'bottom':
|
||||
image = image.transpose(Image.Transpose.ROTATE_180)
|
||||
elif direction == 'right':
|
||||
image = image.transpose(Image.Transpose.ROTATE_90)
|
||||
|
||||
pixels = image.height//2 if pixels is None else int(pixels)
|
||||
assert 0 < pixels < image.height, 'Direction (-D) pixels length must be greater than 0 and less than the image height'
|
||||
|
||||
# the top part of the image is taken from the source image mirrored
|
||||
# coordinates (0,0) are the upper left corner of an image
|
||||
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
|
||||
top = top.crop((0, top.height - pixels, top.width, top.height))
|
||||
|
||||
# setting all alpha of the top part to 0
|
||||
alpha = top.getchannel("A")
|
||||
alpha.paste(0, (0, 0, top.width, top.height))
|
||||
top.putalpha(alpha)
|
||||
|
||||
# taking the bottom from the original image
|
||||
bottom = image.crop((0, 0, image.width, image.height - pixels))
|
||||
|
||||
new_img = image.copy()
|
||||
new_img.paste(top, (0, 0))
|
||||
new_img.paste(bottom, (0, pixels))
|
||||
|
||||
# create a 10% dither in the middle
|
||||
dither = min(image.height//10, pixels)
|
||||
for x in range(0, image.width, 2):
|
||||
for y in range(pixels - dither, pixels + dither):
|
||||
(r, g, b, a) = new_img.getpixel((x, y))
|
||||
new_img.putpixel((x, y), (r, g, b, 0))
|
||||
|
||||
# let's rotate back again
|
||||
if direction == 'left':
|
||||
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
|
||||
elif direction == 'bottom':
|
||||
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
|
||||
elif direction == 'right':
|
||||
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
|
||||
|
||||
return new_img
|
||||
|
ldm/invoke/restoration/realesrgan.py (new file, 86 lines)
@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ESRGAN():
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
self.bg_tile_size = bg_tile_size
|
||||
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path,
|
||||
model=model,
|
||||
tile=self.bg_tile_size,
|
||||
tile_pad=10,
|
||||
pre_pad=0,
|
||||
half=use_half_precision,
|
||||
)
|
||||
|
||||
return bg_upsampler
|
||||
|
||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
try:
|
||||
upsampler = self.load_esrgan_bg_upsampler()
|
||||
except Exception:
|
||||
import traceback
|
||||
import sys
|
||||
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if upsampler_scale == 0:
|
||||
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
|
||||
return image
|
||||
|
||||
if seed is not None:
|
||||
print(
|
||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
||||
)
|
||||
|
||||
# Real-ESRGAN expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
output, _ = upsampler.enhance(
|
||||
bgr_image_array,
|
||||
outscale=upsampler_scale,
|
||||
alpha_upsampler='realesrgan',
|
||||
)
|
||||
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(output[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the original image to match the upscaled result if the sizes differ
if res.size != image.size:
|
||||
image = image.resize(res.size)
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
upsampler = None
|
||||
|
||||
return res
|
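A minimal usage sketch for the upscaler above, matching the process() signature in this file (RealESRGANer fetches the weights from the model_path URL on first use):

from PIL import Image
from ldm.invoke.restoration.realesrgan import ESRGAN

esrgan = ESRGAN(bg_tile_size=400)
img = Image.open('outputs/000001.1234567890.png')
# upscale 2x, then blend 90% of the result over the (resized) original
upscaled = esrgan.process(img, strength=0.9, seed='1234567890', upsampler_scale=2)
upscaled.save('outputs/000001.1234567890.u2.png')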
ldm/invoke/restoration/vqgan_arch.py (new file, 435 lines)
@ -0,0 +1,435 @@
|
||||
'''
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish(x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
||||
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.emb_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
||||
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
mean_distance = torch.mean(d)
|
||||
# find closest encodings
|
||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
||||
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
||||
# [0-1], higher score, higher confidence
|
||||
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
||||
|
||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, {
|
||||
"perplexity": perplexity,
|
||||
"min_encodings": min_encodings,
|
||||
"min_encoding_indices": min_encoding_indices,
|
||||
"min_encoding_scores": min_encoding_scores,
|
||||
"mean_distance": mean_distance
|
||||
}
|
||||
|
||||
def get_codebook_feat(self, indices, shape):
|
||||
# input indices: batch*token_num -> (batch*token_num)*1
|
||||
# shape: batch, height, width, channel
|
||||
indices = indices.view(-1,1)
|
||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
||||
min_encodings.scatter_(1, indices, 1)
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None: # reshape back to match original input shape
|
||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
||||
|
||||
def forward(self, z):
|
||||
hard = self.straight_through if self.training else True
|
||||
|
||||
logits = self.proj(z)
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
||||
|
||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
||||
|
||||
return z_q, diff, {
|
||||
"min_encoding_indices": min_encoding_indices
|
||||
}
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(ResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm1 = normalize(in_channels)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = normalize(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x_in):
|
||||
x = x_in
|
||||
x = self.norm1(x)
|
||||
x = swish(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = swish(x)
|
||||
x = self.conv2(x)
|
||||
if self.in_channels != self.out_channels:
|
||||
x_in = self.conv_out(x_in)
|
||||
|
||||
return x + x_in
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h*w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h*w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = F.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h*w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.attn_resolutions = attn_resolutions
|
||||
|
||||
curr_res = self.resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
|
||||
blocks = []
|
||||
# initial convolution
|
||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
||||
for i in range(self.num_resolutions):
|
||||
block_in_ch = nf * in_ch_mult[i]
|
||||
block_out_ch = nf * ch_mult[i]
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
if curr_res in attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != self.num_resolutions - 1:
|
||||
blocks.append(Downsample(block_in_ch))
|
||||
curr_res = curr_res // 2
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
# normalise and convert to latent size
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.num_resolutions = len(self.ch_mult)
|
||||
self.num_res_blocks = res_blocks
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.in_channels = emb_dim
|
||||
self.out_channels = 3
|
||||
block_in_ch = self.nf * self.ch_mult[-1]
|
||||
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
||||
|
||||
blocks = []
|
||||
# initial conv
|
||||
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
for i in reversed(range(self.num_resolutions)):
|
||||
block_out_ch = self.nf * self.ch_mult[i]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
|
||||
if curr_res in self.attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != 0:
|
||||
blocks.append(Upsample(block_in_ch))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.codebook_size = codebook_size
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
if self.quantizer_type == "nearest":
|
||||
self.beta = beta #0.25
|
||||
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
||||
elif self.quantizer_type == "gumbel":
|
||||
self.gumbel_num_hiddens = emb_dim
|
||||
self.straight_through = gumbel_straight_through
|
||||
self.kl_weight = gumbel_kl_weight
|
||||
self.quantize = GumbelQuantizer(
|
||||
self.codebook_size,
|
||||
self.embed_dim,
|
||||
self.gumbel_num_hiddens,
|
||||
self.straight_through,
|
||||
self.kl_weight
|
||||
)
|
||||
self.generator = Generator(
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_ema' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
||||
x = self.generator(quant)
|
||||
return x, codebook_loss, quant_stats
|
||||
|
||||
|
||||
|
||||
# patch based discriminator
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQGANDiscriminator(nn.Module):
|
||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
||||
super().__init__()
|
||||
|
||||
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
||||
ndf_mult = 1
|
||||
ndf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n, 8)
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n_layers, 8)
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_d' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
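As a shape sanity check, the autoencoder above can be exercised end to end with a toy configuration (the hyperparameters below are illustrative only; the real values come from the CodeFormer/VQGAN checkpoint, not from this sketch):

import torch
from ldm.invoke.restoration.vqgan_arch import VQAutoEncoder

# toy configuration for illustration; not the checkpoint's settings
vqgan = VQAutoEncoder(img_size=64, nf=32, ch_mult=[1, 2, 4])
x = torch.randn(1, 3, 64, 64)
recon, codebook_loss, quant_stats = vqgan(x)
print(recon.shape)                                # torch.Size([1, 3, 64, 64])
print(codebook_loss.item(), quant_stats['perplexity'].item())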
ldm/invoke/seamless.py (new file, 30 lines)
@ -0,0 +1,30 @@
|
||||
import torch.nn as nn
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
|
||||
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
|
||||
|
||||
def configure_model_padding(model, seamless, seamless_axes):
|
||||
"""
|
||||
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
||||
"""
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
if seamless:
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
|
||||
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
else:
|
||||
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
|
||||
if hasattr(m, 'asymmetric_padding_mode'):
|
||||
del m.asymmetric_padding_mode
|
||||
if hasattr(m, 'asymmetric_padding'):
|
||||
del m.asymmetric_padding
|
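A small standalone sketch of what the patch above does to a single Conv2d: after configure_model_padding(..., seamless=True, seamless_axes='x'), the convolution pads circularly along x, so the left and right edges wrap around, which is what makes seamless/tiling output possible:

import torch
import torch.nn as nn
from ldm.invoke.seamless import configure_model_padding

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
model = nn.Sequential(conv)

configure_model_padding(model, True, 'x')   # circular padding along x only
x = torch.randn(1, 1, 8, 8)
y_wrapped = conv(x)                         # left/right edges now see each other

configure_model_padding(model, False, '')   # restore the default zero padding
y_plain = conv(x)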
ldm/invoke/server.py (new file, 281 lines)
@ -0,0 +1,281 @@
|
||||
import argparse
|
||||
import json
|
||||
import copy
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from ldm.invoke.args import Args, metadata_dumps
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.invoke.pngwriter import PngWriter
|
||||
from threading import Event
|
||||
|
||||
def build_opt(post_data, seed, gfpgan_model_exists):
|
||||
opt = Args()
|
||||
opt.parse_args() # initialize defaults
|
||||
setattr(opt, 'prompt', post_data['prompt'])
|
||||
setattr(opt, 'init_img', post_data['initimg'])
|
||||
setattr(opt, 'strength', float(post_data['strength']))
|
||||
setattr(opt, 'iterations', int(post_data['iterations']))
|
||||
setattr(opt, 'steps', int(post_data['steps']))
|
||||
setattr(opt, 'width', int(post_data['width']))
|
||||
setattr(opt, 'height', int(post_data['height']))
|
||||
setattr(opt, 'seamless', 'seamless' in post_data)
|
||||
setattr(opt, 'fit', 'fit' in post_data)
|
||||
setattr(opt, 'mask', 'mask' in post_data)
|
||||
setattr(opt, 'invert_mask', 'invert_mask' in post_data)
|
||||
setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
|
||||
setattr(opt, 'sampler_name', post_data['sampler_name'])
|
||||
|
||||
# embiggen not practical at this point because we have no way of feeding images back into img2img
|
||||
# however, this code is here against that eventuality
|
||||
setattr(opt, 'embiggen', None)
|
||||
setattr(opt, 'embiggen_tiles', None)
|
||||
|
||||
setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0)
|
||||
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
|
||||
setattr(opt, 'progress_images', 'progress_images' in post_data)
|
||||
setattr(opt, 'progress_latents', 'progress_latents' in post_data)
|
||||
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
|
||||
setattr(opt, 'threshold', float(post_data['threshold']))
|
||||
setattr(opt, 'perlin', float(post_data['perlin']))
|
||||
setattr(opt, 'hires_fix', 'hires_fix' in post_data)
|
||||
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
|
||||
setattr(opt, 'with_variations', [])
|
||||
setattr(opt, 'embiggen', None)
|
||||
setattr(opt, 'embiggen_tiles', None)
|
||||
|
||||
broken = False
|
||||
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
|
||||
for part in post_data['with_variations'].split(','):
|
||||
seed_and_weight = part.split(':')
|
||||
if len(seed_and_weight) != 2:
|
||||
print(f'could not parse with_variations part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
try:
|
||||
seed = int(seed_and_weight[0])
|
||||
weight = float(seed_and_weight[1])
|
||||
except ValueError:
|
||||
print(f'could not parse with_variations part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
opt.with_variations.append([seed, weight])
|
||||
|
||||
if broken:
|
||||
raise CanceledException
|
||||
|
||||
if len(opt.with_variations) == 0:
|
||||
opt.with_variations = None
|
||||
|
||||
return opt
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
model = None
|
||||
outdir = None
|
||||
canceled = Event()
|
||||
|
||||
def do_GET(self):
|
||||
if self.path == "/":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
with open("./static/legacy_web/index.html", "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
elif self.path == "/config.js":
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/javascript")
|
||||
self.end_headers()
|
||||
config = {
|
||||
'gfpgan_model_exists': self.gfpgan_model_exists
|
||||
}
|
||||
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
||||
elif self.path == "/run_log.json":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
output = []
|
||||
|
||||
log_file = os.path.join(self.outdir, "legacy_web_log.txt")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r") as log:
|
||||
for line in log:
|
||||
url, config = line.split(": {", maxsplit=1)
|
||||
config = json.loads("{" + config)
|
||||
config["url"] = url.lstrip(".")
|
||||
if os.path.exists(url):
|
||||
output.append(config)
|
||||
|
||||
self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8"))
|
||||
elif self.path == "/cancel":
|
||||
self.canceled.set()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes('{}', 'utf8'))
|
||||
else:
|
||||
path_dir = os.path.dirname(self.path)
|
||||
out_dir = os.path.realpath(self.outdir.rstrip('/'))
|
||||
if self.path.startswith('/static/legacy_web/'):
|
||||
path = '.' + self.path
|
||||
elif out_dir.replace('\\', '/').endswith(path_dir):
|
||||
file = os.path.basename(self.path)
|
||||
path = os.path.join(self.outdir,file)
|
||||
else:
|
||||
self.send_response(404)
|
||||
return
|
||||
mime_type = mimetypes.guess_type(path)[0]
|
||||
if mime_type is not None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", mime_type)
|
||||
self.end_headers()
|
||||
with open(path, "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
||||
def do_POST(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = json.loads(self.rfile.read(content_length))
|
||||
opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)
|
||||
|
||||
self.canceled.clear()
|
||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||
# across images generated by each call to prompt2img(), so we define it in
|
||||
# the outer scope of image_done()
|
||||
config = post_data.copy() # Shallow copy
|
||||
config['initimg'] = config.pop('initimg_name', '')
|
||||
|
||||
images_generated = 0 # helps keep track of when upscaling is started
|
||||
images_upscaled = 0 # helps keep track of when upscaling is completed
|
||||
pngwriter = PngWriter(self.outdir)
|
||||
|
||||
prefix = pngwriter.unique_prefix()
|
||||
# if upscaling is requested, then this will be called twice, once when
|
||||
# the images are first generated, and then again after upscaling
|
||||
# is complete. The upscaling replaces the original file, so the second
|
||||
# entry should not be inserted into the image list.
|
||||
# LS: This repeats code in dream.py
|
||||
def image_done(image, seed, upscaled=False, first_seed=None):
|
||||
name = f'{prefix}.{seed}.png'
|
||||
iter_opt = copy.copy(opt)
|
||||
if opt.variation_amount > 0:
|
||||
this_variation = [[seed, opt.variation_amount]]
|
||||
if opt.with_variations is None:
|
||||
iter_opt.with_variations = this_variation
|
||||
else:
|
||||
iter_opt.with_variations = opt.with_variations + this_variation
|
||||
iter_opt.variation_amount = 0
|
||||
formatted_prompt = opt.dream_prompt_str(seed=seed)
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image,
|
||||
dream_prompt = formatted_prompt,
|
||||
metadata = metadata_dumps(iter_opt,
|
||||
seeds = [seed],
|
||||
model_hash = self.model.model_hash
|
||||
),
|
||||
name = name,
|
||||
)
|
||||
|
||||
if int(config['seed']) == -1:
|
||||
config['seed'] = seed
|
||||
# Append post_data to log, but only once!
|
||||
if not upscaled:
|
||||
with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log:
|
||||
log.write(f"{path}: {json.dumps(config)}\n")
|
||||
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
# control state of the "postprocessing..." message
|
||||
upscaling_requested = opt.upscale or opt.facetool_strength > 0
|
||||
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
if upscaled:
|
||||
images_upscaled += 1
|
||||
else:
|
||||
images_generated += 1
|
||||
if upscaling_requested:
|
||||
action = None
|
||||
if images_generated >= opt.iterations:
|
||||
if images_upscaled < opt.iterations:
|
||||
action = 'upscaling-started'
|
||||
else:
|
||||
action = 'upscaling-done'
|
||||
if action:
|
||||
x = images_upscaled + 1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
|
||||
step_index = 1
|
||||
def image_progress(sample, step):
|
||||
if self.canceled.is_set():
|
||||
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
|
||||
raise CanceledException
|
||||
path = None
|
||||
# since rendering images is moderately expensive, only render every 5th image
|
||||
# and don't bother with the last one, since it'll render anyway
|
||||
nonlocal step_index
|
||||
|
||||
wants_progress_latents = opt.progress_latents
|
||||
wants_progress_image = opt.progress_images and step % 5 == 0
|
||||
|
||||
if (wants_progress_image | wants_progress_latents) and step < opt.steps - 1:
|
||||
image = self.model.sample_to_image(sample) if wants_progress_image \
|
||||
else self.model.sample_to_lowres_estimated_image(sample)
|
||||
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
|
||||
name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
|
||||
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
|
||||
path = step_writer.save_image_and_prompt_to_png(image, dream_prompt=metadata, name=name)
|
||||
step_index += 1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': 'step', 'step': step + 1, 'url': path}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
try:
|
||||
if opt.init_img is None:
|
||||
# Run txt2img
|
||||
self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
initimg = opt.init_img.split(",")[1] # Ignore mime type
|
||||
f.write(base64.b64decode(initimg))
|
||||
opt1 = argparse.Namespace(**vars(opt))
|
||||
opt1.init_img = "./img2img-tmp.png"
|
||||
|
||||
try:
|
||||
# Run img2img
|
||||
self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
|
||||
finally:
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
except CanceledException:
|
||||
print(f"Canceled.")
|
||||
return
|
||||
except Exception as e:
|
||||
print("Error happened")
|
||||
print(e)
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': 'error',
|
||||
'message': str(e),
|
||||
'type': e.__class__.__name__}
|
||||
) + '\n',"utf-8"))
|
||||
raise e
|
||||
|
||||
|
||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||
def __init__(self, server_address):
|
||||
super(ThreadingDreamServer, self).__init__(server_address, DreamServer)
|
@ -4,7 +4,7 @@ import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
||||
from ldm.invoke.pngwriter import PngWriter, PromptFormatter
|
||||
from threading import Event
|
||||
|
||||
def build_opt(post_data, seed, gfpgan_model_exists):
|
||||
@ -125,7 +125,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
||||
# TODO temporarily commented out, import fails for some reason
|
||||
# from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
||||
gfpgan_model_exists = False
|
||||
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = json.loads(self.rfile.read(content_length))
|
||||
@ -148,7 +150,8 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
# the images are first generated, and then again after upscaling
|
||||
# is complete. The upscaling replaces the original file, so the second
|
||||
# entry should not be inserted into the image list.
|
||||
def image_done(image, seed, upscaled=False):
|
||||
def image_done(image, seed, upscaled=False, first_seed=-1, use_prefix=None):
|
||||
print(f'First seed: {first_seed}')
|
||||
name = f'{prefix}.{seed}.png'
|
||||
iter_opt = argparse.Namespace(**vars(opt)) # copy
|
||||
if opt.variation_amount > 0:
|
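For orientation, build_opt() above expects a JSON body along these lines. Field names are taken from the code; the exact payload the legacy web UI sends, and the default port, are assumptions:

import requests

payload = {
    'prompt': 'a photograph of an astronaut riding a horse',
    'initimg': None,            # base64 data URL for img2img, or None for txt2img
    'strength': 0.75,
    'iterations': 1,
    'steps': 50,
    'width': 512,
    'height': 512,
    'cfg_scale': 7.5,
    'sampler_name': 'k_lms',
    'seed': -1,
    'threshold': 0.0,
    'perlin': 0.0,
    'variation_amount': 0.0,
    'with_variations': '',
    'facetool_strength': 0.0,
    'upscale_level': '',
    'upscale_strength': 0.75,
}
# the handler streams newline-delimited JSON events: 'step', 'result', 'error', ...
with requests.post('http://localhost:9090/', json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        print(line.decode())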
ldm/invoke/txt2mask.py (new file, 130 lines)
@ -0,0 +1,130 @@
|
||||
'''Makes available the Txt2Mask class, which assists in the automatic
|
||||
assignment of masks via text prompt using clipseg.
|
||||
|
||||
Here is typical usage:
|
||||
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
from PIL import Image
|
||||
|
||||
txt2mask = Txt2Mask(self.device)
|
||||
segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
|
||||
|
||||
# this will return a grayscale Image of the segmented data
|
||||
grayscale = segmented.to_grayscale()
|
||||
|
||||
# this will return a semi-transparent image in which the
|
||||
# selected object(s) are opaque and the rest is at various
|
||||
# levels of transparency
|
||||
transparent = segmented.to_transparent()
|
||||
|
||||
# this will return a masked image suitable for use in inpainting:
|
||||
mask = segmented.to_mask(threshold=0.5)
|
||||
|
||||
The threshold passed to to_mask() selects the pixels whose confidence
exceeds that value for inclusion in the mask. Values range from 0.0 to
1.0; the higher the threshold, the more confident the model must be
before a pixel is included. In limited testing, values around 0.5 work
well.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
|
||||
CLIP_VERSION = 'ViT-B/16'
|
||||
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
|
||||
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
|
||||
CLIPSEG_SIZE = 352
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image:Image, heatmap:torch.Tensor):
|
||||
self.heatmap = heatmap
|
||||
self.image = image
|
||||
|
||||
def to_grayscale(self,invert:bool=False)->Image:
|
||||
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
|
||||
|
||||
def to_mask(self,threshold:float=0.5)->Image:
|
||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
|
||||
|
||||
def to_transparent(self,invert:bool=False)->Image:
|
||||
transparent_image = self.image.copy()
|
||||
# For img2img, we want the selected regions to be transparent,
|
||||
# but to_grayscale() returns the opposite. Thus invert.
|
||||
gs = self.to_grayscale(not invert)
|
||||
transparent_image.putalpha(gs)
|
||||
return transparent_image
|
||||
|
||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||
def _rescale(self, heatmap:Image)->Image:
|
||||
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
||||
resized_image = heatmap.resize(
|
||||
(size,size),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
return resized_image.crop((0,0,self.image.width,self.image.height))
|
||||
|
||||
class Txt2Mask(object):
|
||||
'''
|
||||
Create new Txt2Mask object. The optional device argument can be one of
|
||||
'cuda', 'mps' or 'cpu'.
|
||||
'''
|
||||
def __init__(self,device='cpu',refined=False):
|
||||
print('>> Initializing clipseg model for text to mask inference')
|
||||
self.device = device
|
||||
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
|
||||
self.model.eval()
|
||||
# initially keep everything on the CPU to conserve GPU memory
|
||||
self.model.to('cpu')
|
||||
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def segment(self, image, prompt:str) -> SegmentedGrayscale:
|
||||
'''
|
||||
Given a prompt string such as "a bagel", tries to identify the object in the
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
'''
|
||||
self._to_device(self.device)
|
||||
prompts = [prompt] # right now we operate on just a single prompt at a time
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
|
||||
])
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert('RGB')
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
img = self._scale_and_crop(image)
|
||||
img = transform(img).unsqueeze(0)
|
||||
|
||||
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
|
||||
heatmap = torch.sigmoid(preds[0][0]).cpu()
|
||||
self._to_device('cpu')
|
||||
return SegmentedGrayscale(image, heatmap)
|
||||
|
||||
def _to_device(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
def _scale_and_crop(self, image:Image)->Image:
|
||||
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
|
||||
if image.width > image.height: # width is constraint
|
||||
scale = CLIPSEG_SIZE / image.width
|
||||
else:
|
||||
scale = CLIPSEG_SIZE / image.height
|
||||
scaled_image.paste(
|
||||
image.resize(
|
||||
(int(scale * image.width),
|
||||
int(scale * image.height)
|
||||
),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
),box=(0,0)
|
||||
)
|
||||
return scaled_image
|
@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
ldm/models/diffusion/cross_attention_control.py (new file, 201 lines)
@ -0,0 +1,201 @@
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
class CrossAttentionControl:
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
|
||||
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
|
||||
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
|
||||
"""
|
||||
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
if edited_conditioning is not None:
|
||||
assert len(edit_opcodes) == len(edit_options), \
|
||||
"there must be 1 edit_options dict for each edit_opcodes tuple"
|
||||
non_none_edit_options = [x for x in edit_options if x is not None]
|
||||
assert len(non_none_edit_options)>0, "missing edit_options"
|
||||
if len(non_none_edit_options)>1:
|
||||
print('warning: cross-attention control options are not working properly for >1 edit')
|
||||
self.edit_options = non_none_edit_options[0]
|
||||
|
||||
class Context:
|
||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
@classmethod
|
||||
def remove_cross_attention_control(cls, model):
|
||||
cls.remove_attention_function(model)
|
||||
|
||||
@classmethod
|
||||
def setup_cross_attention_control(cls, model,
|
||||
cross_attention_control_args: Arguments
|
||||
):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = cross_attention_control_args.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.zeros(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
cls.inject_attention_function(model)
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
||||
m.last_attn_slice_mask = None
|
||||
m.last_attn_slice_indices = None
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
|
||||
m.last_attn_slice_mask = mask.to(device)
|
||||
m.last_attn_slice_indices = indices.to(device)
|
||||
|
||||
|
||||
class CrossAttentionType(Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
@classmethod
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = context.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(cls.CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
to_control.append(cls.CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
||||
return [module for name, module in model.named_modules() if
|
||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||
|
||||
@classmethod
|
||||
def clear_requests(cls, model, clear_attn_slice=True):
|
||||
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
|
||||
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
|
||||
for m in self_attention_modules+tokens_attention_modules:
|
||||
m.save_last_attn_slice = False
|
||||
m.use_last_attn_slice = False
|
||||
if clear_attn_slice:
|
||||
m.last_attn_slice = None
|
||||
|
||||
@classmethod
|
||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
# clear out the saved slice in case the outermost dim changes
|
||||
m.last_attn_slice = None
|
||||
m.save_last_attn_slice = True
|
||||
|
||||
@classmethod
|
||||
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
m.use_last_attn_slice = True
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def inject_attention_function(cls, unet):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
|
||||
|
||||
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
|
||||
|
||||
attn_slice = suggested_attention_slice
|
||||
if dim is not None:
|
||||
start = offset
|
||||
end = start+slice_size
|
||||
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
#else:
|
||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
|
||||
if self.use_last_attn_slice:
|
||||
if dim is None:
|
||||
last_attn_slice = self.last_attn_slice
|
||||
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
else:
|
||||
last_attn_slice = self.last_attn_slice[offset]
|
||||
|
||||
if self.last_attn_slice_mask is None:
|
||||
# just use everything
|
||||
attn_slice = last_attn_slice
|
||||
else:
|
||||
last_attn_slice_mask = self.last_attn_slice_mask
|
||||
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
|
||||
|
||||
this_attn_slice = attn_slice
|
||||
this_attn_slice_mask = 1 - last_attn_slice_mask
|
||||
attn_slice = this_attn_slice * this_attn_slice_mask + \
|
||||
remapped_last_attn_slice * last_attn_slice_mask
|
||||
|
||||
if self.save_last_attn_slice:
|
||||
if dim is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
else:
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = { offset: attn_slice }
|
||||
else:
|
||||
self.last_attn_slice[offset] = attn_slice
|
||||
|
||||
return attn_slice
|
||||
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.last_attn_slice = None
|
||||
module.last_attn_slice_indices = None
|
||||
module.last_attn_slice_mask = None
|
||||
module.use_last_attn_weights = False
|
||||
module.use_last_attn_slice = False
|
||||
module.save_last_attn_slice = False
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
|
||||
@classmethod
|
||||
def remove_attention_function(cls, unet):
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.set_attention_slice_wrangler(None)
|
||||
|
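The edit_opcodes consumed by Arguments above follow difflib.SequenceMatcher conventions. A sketch of how a caller might derive them from two tokenized prompts (the token ids below are purely illustrative):

from difflib import SequenceMatcher

# hypothetical token id sequences for the original and edited prompts
original_tokens = [320, 1125, 539, 530, 4558, 2368]   # "a painting of a fluffy cat"
edited_tokens   = [320, 1125, 539, 530, 4558, 1929]   # "a painting of a fluffy dog"

edit_opcodes = SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
# -> [('equal', 0, 5, 0, 5), ('replace', 5, 6, 5, 6)]
# only the 'equal' spans are required: they mark tokens whose attention maps
# are carried over; everything else falls through to the edited conditioning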
@ -1,293 +1,48 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
class DDIMSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
super().__init__(model,schedule,model.num_timesteps,device)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(dtype=torch.float32, device=self.device)
|
||||
setattr(self, name, attr)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=True,
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
# This routine gets called from img2img
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
img = x_T
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = (
|
||||
timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
)
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc='DDIM Sampler',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# This routine gets called from ddim_sampling() and decode()
|
||||
# This is the central routine
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
def p_sample(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
@@ -295,16 +50,17 @@ class DDIMSampler(object):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(
|
||||
x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index
|
||||
)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
@@ -351,83 +107,5 @@ class DDIMSampler(object):
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
return x_prev, pred_x0, None
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
):
|
||||
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
xdec_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
|
||||
if img_callback:
|
||||
img_callback(x_dec, i)
|
||||
|
||||
return x_dec
|
||||
|
@@ -19,6 +19,7 @@ from functools import partial
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from omegaconf import ListConfig
|
||||
import urllib
|
||||
|
||||
from ldm.util import (
|
||||
@@ -106,7 +107,7 @@ class DDPM(pl.LightningModule):
|
||||
], 'currently only supporting "eps" and "x0"'
|
||||
self.parameterization = parameterization
|
||||
print(
|
||||
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
f' | {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
)
|
||||
self.cond_stage_model = None
|
||||
self.clip_denoised = clip_denoised
|
||||
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self.model)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
self.use_scheduler = scheduler_config is not None
|
||||
if self.use_scheduler:
|
||||
@@ -701,7 +702,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
@rank_zero_only
|
||||
@torch.no_grad()
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
|
||||
# only for very first batch
|
||||
if (
|
||||
self.scale_by_std
|
||||
@@ -820,21 +821,21 @@ class LatentDiffusion(DDPM):
|
||||
)
|
||||
return self.scale_factor * z
|
||||
|
||||
def get_learned_conditioning(self, c):
|
||||
def get_learned_conditioning(self, c, **kwargs):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||
self.cond_stage_model.encode
|
||||
):
|
||||
c = self.cond_stage_model.encode(
|
||||
c, embedding_manager=self.embedding_manager
|
||||
c, embedding_manager=self.embedding_manager,**kwargs
|
||||
)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
else:
|
||||
c = self.cond_stage_model(c)
|
||||
c = self.cond_stage_model(c, **kwargs)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
|
||||
return c
|
||||
|
||||
def meshgrid(self, h, w):
|
||||
@@ -1353,7 +1354,7 @@ class LatentDiffusion(DDPM):
|
||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||
rescale_latent = 2 ** (num_downs)
|
||||
|
||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||
tl_patch_coordinates = [
|
||||
(
|
||||
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self,
|
||||
@@ -1890,7 +1909,7 @@ class LatentDiffusion(DDPM):
|
||||
N=8,
|
||||
n_row=4,
|
||||
sample=True,
|
||||
ddim_steps=200,
|
||||
ddim_steps=50,
|
||||
ddim_eta=1.0,
|
||||
return_keys=None,
|
||||
quantize_denoised=True,
|
||||
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
@@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
cond_img = torch.stack(bbox_imgs, dim=0)
|
||||
logs['bbox_image'] = cond_img
|
||||
return logs
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
finetune_keys=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(
|
||||
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
|
||||
):
|
||||
# note: restricted to non-trainable encoders currently
|
||||
assert (
|
||||
not self.cond_stage_trainable
|
||||
), "trainable cond stages not yet supported for inpainting"
|
||||
z, c, x, xrec, xc = super().get_input(
|
||||
batch,
|
||||
self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=True,
|
||||
return_original_cond=True,
|
||||
bs=bs,
|
||||
)
|
||||
|
||||
assert exists(self.concat_keys)
|
||||
c_cat = list()
|
||||
for ck in self.concat_keys:
|
||||
cc = (
|
||||
rearrange(batch[ck], "b h w c -> b c h w")
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
if bs is not None:
|
||||
cc = cc[:bs]
|
||||
cc = cc.to(self.device)
|
||||
bchw = z.shape
|
||||
if ck != self.masked_image_key:
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
if return_first_stage_outputs:
|
||||
return z, all_conds, x, xrec, xc
|
||||
return z, all_conds
|
||||
|
@@ -1,39 +1,152 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""

import k_diffusion as K
import torch
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
from torch import nn

from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent


# at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule
STEP_THRESHOLD = 30

def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
return result
maxval = 0.0 + torch.max(result).cpu().numpy()
minval = 0.0 + torch.min(result).cpu().numpy()
if maxval < threshold and minval > -threshold:
return result
if maxval > threshold:
maxval = min(max(1, scale*maxval), threshold)
if minval < -threshold:
minval = max(min(-1, scale*minval), -threshold)
return torch.clamp(result, min=minval, max=maxval)
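# Illustrative example of the thresholding above (hypothetical values threshold=1.0, scale=0.7):
# for a tensor whose max is 3.0 and whose min is -0.5, the positive side is clamped to
# min(max(1, 0.7*3.0), 1.0) = 1.0 while the negative side is left untouched because -0.5 > -threshold:
#
#   t = torch.tensor([3.0, -0.5, 0.2])
#   cfg_apply_threshold(t, threshold=1.0, scale=0.7)   # -> tensor([1.0000, -0.5000, 0.2000])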

class CFGDenoiser(nn.Module):
def __init__(self, model):
def __init__(self, model, threshold = 0, warmup = 0):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))

def prepare_to_sample(self, t_enc, **kwargs):

extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.remove_cross_attention_control()


def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
else:
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(next_x, thresh)
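# Worked example of the warmup above (hypothetical values: S=20 steps, threshold=2.0):
# KSampler constructs CFGDenoiser with warmup=max(0.8*20, 20-10)=16, so self.warmup starts
# at max(16/10, 1)=1.6 and the first call uses thresh = max(1, 1 + (2.0-1)*(1.6/16)) = 1.1;
# thresh then ramps toward 2.0 as self.warmup climbs to warmup_max, after which the full
# threshold is applied on every step.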
|
||||
|
||||
|
||||
class KSampler(object):
|
||||
class KSampler(Sampler):
|
||||
def __init__(self, model, schedule='lms', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = K.external.CompVisDenoiser(model)
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
denoiser = K.external.CompVisDenoiser(model)
|
||||
super().__init__(
|
||||
denoiser,
|
||||
schedule,
|
||||
steps=model.num_timesteps,
|
||||
)
|
||||
self.sigmas = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
|
||||
if self.karras_max is None:
|
||||
self.karras_max = STEP_THRESHOLD
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
outer_model = self.model
|
||||
self.model = outer_model.inner_model
|
||||
super().make_schedule(
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
)
|
||||
self.model = outer_model
|
||||
self.ddim_num_steps = ddim_num_steps
|
||||
# we don't need both of these sigmas, but storing them here to make
|
||||
# comparison easier later on
|
||||
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
|
||||
self.karras_sigmas = K.sampling.get_sigmas_karras(
|
||||
n=ddim_num_steps,
|
||||
sigma_min=self.model.sigmas[0].item(),
|
||||
sigma_max=self.model.sigmas[-1].item(),
|
||||
rho=7.,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# most of these arguments are ignored and are only present for compatibility with
|
||||
if ddim_num_steps >= self.karras_max:
|
||||
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
|
||||
self.sigmas = self.model_sigmas
|
||||
else:
|
||||
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
|
||||
self.sigmas = self.karras_sigmas
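# Consequence of the branch above (illustrative): with the default STEP_THRESHOLD of 30,
# a 20-step run uses the Karras sigmas while a 50-step run falls back to the model's own
# noise schedule; the cutoff can be overridden by passing a karras_max keyword argument
# to KSampler().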
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||
# in the lstein/k-diffusion branch.
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
z_enc,
|
||||
cond,
|
||||
t_enc,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
**kwargs
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
S = t_enc,
|
||||
x_T = z_enc,
|
||||
shape = z_enc.shape[1:],
|
||||
conditioning = cond,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask,
|
||||
**kwargs
|
||||
)
|
||||
return samples
|
||||
|
||||
# this is a no-op, provided here for compatibility with ddim and plms samplers
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
return x0
|
||||
|
||||
# Most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
@@ -58,31 +171,128 @@ class KSampler(object):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
def route_callback(k_callback_values):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||
img_callback(k_callback_values['x'],k_callback_values['i'])
|
||||
|
||||
sigmas = self.model.get_sigmas(S)
|
||||
# if make_schedule() hasn't been called, we do it now
|
||||
if self.sigmas is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
# more randomness to the starting image.
|
||||
if x_T is not None:
|
||||
x = x_T * sigmas[0]
|
||||
if x0 is not None:
|
||||
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
|
||||
else:
|
||||
x = x_T * sigmas[0]
|
||||
else:
|
||||
x = (
|
||||
torch.randn([batch_size, *shape], device=self.device)
|
||||
* sigmas[0]
|
||||
) # for GPU draw
|
||||
model_wrap_cfg = CFGDenoiser(self.model)
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
return (
|
||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||
sampling_result = (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
# a workaround is found.
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
self.model_wrap = CFGDenoiser(self.model)
|
||||
extra_args = {
|
||||
'cond': cond,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
if self.s_in is None:
|
||||
self.s_in = img.new_ones([img.shape[0]])
|
||||
if self.ds is None:
|
||||
self.ds = []
|
||||
|
||||
# terrible, confusing names here
|
||||
steps = self.ddim_num_steps
|
||||
t_enc = self.t_enc
|
||||
|
||||
# sigmas is a full steps in length, but t_enc might
|
||||
# be less. We start in the middle of the sigma array
|
||||
# and work our way to the end after t_enc steps.
|
||||
# index starts at t_enc and works its way to zero,
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
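# Illustrative arithmetic for the indexing above (hypothetical t_enc=30): the caller's
# index counts down 29, 28, ..., 0, so s_index = 30 - index - 1 counts up 0, 1, ..., 29,
# walking forward through the sigma schedule one entry per p_sample() call.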
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
self.sigmas,
|
||||
s_index,
|
||||
s_in = self.s_in,
|
||||
ds = self.ds,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
return img, None, None
|
||||
|
||||
# REVIEW THIS METHOD: it has never been tested. In particular,
|
||||
# we should not be multiplying by self.sigmas[0] if we
|
||||
# are at an intermediate step in img2img. See similar in
|
||||
# sample() which does work.
|
||||
def get_initial_image(self,x_T,shape,steps):
|
||||
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
|
||||
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
|
||||
if x_T is not None:
|
||||
return x_T + x
|
||||
else:
|
||||
return x
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def q_sample(self,x0,ts):
|
||||
'''
|
||||
Overrides parent method to return the q_sample of the inner model.
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.inner_model.model.conditioning_key
|
||||
|
||||
|
@@ -4,303 +4,49 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
)
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
class PLMSSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device if device else choose_torch_device()
|
||||
super().__init__(model,schedule,model.num_timesteps, device)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.float32).to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=True,
|
||||
):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
# print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
img = x_T
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = (
|
||||
list(reversed(range(0, timesteps)))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = (
|
||||
timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
)
|
||||
# print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc='PLMS Sampler',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full(
|
||||
(b,),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# this is the essential routine
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
def p_sample(
|
||||
self,
|
||||
x, # image, called 'img' elsewhere
|
||||
c, # conditioning, called 'cond' elsewhere
|
||||
t, # timesteps, called 'ts' elsewhere
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=[],
|
||||
t_next=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
@@ -309,18 +55,15 @@ class PLMSSampler(object):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(
|
||||
x_in, t_in, c_in
|
||||
).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index)
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
|
ldm/models/diffusion/sampler.py (new file, 450 lines)
@@ -0,0 +1,450 @@
'''
ldm.models.diffusion.sampler

Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
'''
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)

class Sampler(object):
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
self.model = model
self.ddim_timesteps = None
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
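# Usage sketch (assumption, mirroring the DDIMSampler change earlier in this diff):
# concrete samplers hand their model and total step count up to this base class, e.g.
#
#   class DDIMSampler(Sampler):
#       def __init__(self, model, schedule='linear', device=None, **kwargs):
#           super().__init__(model, schedule, model.num_timesteps, device)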

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
|
||||
|
||||
# This method was copied over from ddim.py and probably does stuff that is
|
||||
# ddim-specific. Disentangle at some point.
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
self.total_steps = ddim_num_steps
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S, # S is steps
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# check to see if make_schedule() has run, and if not, run it
|
||||
if self.ddim_timesteps is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
ts = self.get_timesteps(S)
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
shape = (batch_size, C, H, W)
|
||||
samples, intermediates = self.do_sampling(
|
||||
conditioning,
|
||||
shape,
|
||||
timesteps=ts,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
steps=S,
|
||||
**kwargs
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def do_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
timesteps=None,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
steps=None,
|
||||
**kwargs
|
||||
):
|
||||
b = shape[0]
|
||||
time_range = (
|
||||
list(reversed(range(0, timesteps)))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
|
||||
total_steps=steps
|
||||
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc=f'{self.__class__.__name__}',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||
img = self.get_initial_image(x_T,shape,total_steps)
|
||||
|
||||
# probably don't need this at all
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(b,),
|
||||
step,
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
)
|
||||
ts_next = torch.full(
|
||||
(b,),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
step_count=steps
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(img,i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# NOTE that decode() and sample() are almost the same code, and do the same thing.
|
||||
# The variable names are changed in order to be confusing.
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
all_timesteps_count = None,
|
||||
**kwargs
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
ts_next = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||
|
||||
outs = self.p_sample(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
t_next = ts_next,
|
||||
step_count=len(self.ddim_timesteps)
|
||||
)
|
||||
|
||||
x_dec, pred_x0, e_t = outs
|
||||
if img_callback:
|
||||
img_callback(x_dec,i)
|
||||
|
||||
return x_dec
|
||||
|
||||
def get_initial_image(self,x_T,shape,timesteps=None):
|
||||
if x_T is None:
|
||||
return torch.randn(shape, device=self.device)
|
||||
else:
|
||||
return x_T
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
steps=None,
|
||||
):
|
||||
raise NotImplementedError("p_sample() must be implemented in a descendent class")
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
'''
|
||||
Hook that will be called right before the very first invocation of p_sample()
|
||||
to allow subclass to do additional initialization. t_enc corresponds to the actual
|
||||
number of steps that will be run, and may be less than total steps if img2img is
|
||||
active.
|
||||
'''
|
||||
pass
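# Minimal subclass sketch (assumption; mirrors what KSampler.prepare_to_sample does above):
# a sampler can override this hook to stash per-run state before its first p_sample() call, e.g.
#
#   class MySampler(Sampler):
#       def prepare_to_sample(self, t_enc, **kwargs):
#           self.t_enc = t_enc   # number of steps that will actually run (< total for img2img)
#           self.extra_conditioning_info = kwargs.get('extra_conditioning_info', None)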
|
||||
|
||||
def get_timesteps(self,ddim_steps):
|
||||
'''
|
||||
The ddim and plms samplers work on timesteps. This method is called after
|
||||
ddim_timesteps are created in make_schedule(), and selects the portion of
|
||||
timesteps that will be used for sampling, depending on the t_enc in img2img.
|
||||
'''
|
||||
return self.ddim_timesteps[:ddim_steps]
|
||||
|
||||
def q_sample(self,x0,ts):
|
||||
'''
|
||||
Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
'''
|
||||
return self.model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.model.conditioning_key
|
||||
|
||||
def uses_inpainting_model(self)->bool:
|
||||
return self.conditioning_key() in ('hybrid','concat')
|
||||
|
||||
def adjust_settings(self,**kwargs):
|
||||
'''
|
||||
This is a catch-all method for adjusting any instance variables
|
||||
after the sampler is instantiated. No type-checking performed
|
||||
here, so use with care!
|
||||
'''
|
||||
for k in kwargs.keys():
|
||||
try:
|
||||
setattr(self,k,kwargs[k])
|
||||
except AttributeError:
|
||||
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
|
ldm/models/diffusion/shared_invokeai_diffusion.py (new file, 225 lines)
@@ -0,0 +1,225 @@
from math import ceil
from typing import Callable, Optional, Union

import torch

from ldm.models.diffusion.cross_attention_control import CrossAttentionControl


class InvokeAIDiffuserComponent:
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.

At the moment it includes the following features:
* Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
'''

class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
self.cross_attention_control_args = cross_attention_control_args

@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None

def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback

def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count

def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
conditioning: Union[torch.Tensor,dict],
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: Optional[int]=None
|
||||
):
|
||||
"""
|
||||
:param x: current latents
|
||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. If None, will be estimated by comparing sigma against self.model.sigmas.
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||
|
||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||
|
||||
if wants_hybrid_conditioning:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
|
||||
elif wants_cross_attention_control:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||
else:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
|
||||
return combined_next_x
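# Worked scalar example of the guidance arithmetic above (numbers are illustrative):
# with unconditioned_next_x = 0.2, conditioned_next_x = 0.5 and a CFG scale of 7.5,
# combined_next_x = 0.2 + (0.5 - 0.2) * 7.5 = 2.45 -- the direction suggested by the
# conditioning is kept and its magnitude is amplified relative to the unconditioned output.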
|
||||
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
# fast batched path
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
||||
both_conditionings).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
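# Shape sketch (sizes are illustrative): for latents x of shape (1, 4, 64, 64) and
# conditioning tensors of shape (1, 77, 768), x_twice is (2, 4, 64, 64) and
# both_conditionings is (2, 77, 768); a single forward pass then produces both
# predictions, and .chunk(2) splits them back into the unconditioned and conditioned halves.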
|
||||
|
||||
|
||||
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = dict()
|
||||
for k in conditioning:
|
||||
if isinstance(conditioning[k], list):
|
||||
both_conditionings[k] = [
|
||||
torch.cat([unconditioning[k][i], conditioning[k][i]])
|
||||
for i in range(len(conditioning[k]))
|
||||
]
|
||||
else:
|
||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
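# Layout sketch for hybrid (inpainting) conditioning, assuming the usual latent-diffusion
# dict keys (an assumption, not shown in this file): conditioning might look like
#   {'c_concat': [masked_image_latents], 'c_crossattn': [text_embedding]}
# which is why list-valued entries are concatenated element-wise above.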
|
||||
|
||||
|
||||
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
except RuntimeError:
|
||||
# make sure we clean out the attention slices we're storing on the model
|
||||
# TODO don't store things on the model
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
raise
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
# find the best possible index of the current sigma in the sigma sequence
|
||||
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
|
||||
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
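# Worked example of the sigma fallback above (sizes illustrative): with 10 entries in
# self.model.sigmas and the current sigma matching index 7, percent_through = 1.0 - 8/10 = 0.2;
# large sigmas early in sampling therefore map to values near 0, and small sigmas late in
# sampling approach (but never reach) 1.0.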
|
||||
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c,weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c,weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings)/2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index*2
|
||||
chunk_size = min(2, len(conditionings)-offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset:offset+2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
@ -1,5 +1,7 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
@ -90,7 +92,7 @@ class LinearAttention(nn.Module):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
|
||||
return x+h_
|
||||
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
@ -168,116 +171,114 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
mem_av = psutil.virtual_memory().available / (1024**3)
|
||||
if mem_av > 32:
|
||||
self.einsum_op = self.einsum_op_v1
|
||||
elif mem_av > 12:
|
||||
self.einsum_op = self.einsum_op_v2
|
||||
else:
|
||||
self.einsum_op = self.einsum_op_v3
|
||||
del mem_av
|
||||
else:
|
||||
self.einsum_op = self.einsum_op_v4
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
# mps 64-128 GB
|
||||
def einsum_op_v1(self, q, k, v, r1):
|
||||
if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
else:
|
||||
# q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
|
||||
# needs around half of that slice_size to not generate noise
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
self.attention_slice_wrangler = None
|
||||
|
||||
# mps 16-32 GB (can be optimized)
|
||||
def einsum_op_v2(self, q, k, v, r1):
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
|
||||
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
self is the current CrossAttention module for which the callback is being invoked.
|
||||
attention_scores are the scores for attention
|
||||
suggested_attention_slice is a softmax(dim=-1) over attention_scores
|
||||
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If dim is >= 0, offset and slice_size specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
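# Minimal callback sketch matching the signature described above (purely illustrative --
# it returns the suggested slice unchanged):
#   def passthrough_wrangler(module, attention_scores, suggested_attention_slice,
#                            dim, offset, slice_size):
#       return suggested_attention_slice
#   cross_attention_module.set_attention_slice_wrangler(passthrough_wrangler)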
|
||||
|
||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
if self.attention_slice_wrangler is not None:
|
||||
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
return einsum('b i j, b j d -> b i d', attention_slice, v)
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
# mps 8 GB
|
||||
def einsum_op_v3(self, q, k, v, r1):
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
|
||||
del s2
|
||||
return r1
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
# cuda
|
||||
def einsum_op_v4(self, q, k, v, r1):
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
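# Worked example (shapes illustrative): for q of shape (16, 4096, 40) in float16 with
# k.shape[1] == 4096, size_mb = 16*4096*4096*2 / 2**20 = 512. With max_tensor_mb = 32
# (the CPU default further down), div = 1 << int((512-1)/32).bit_length() = 16, which is
# <= q.shape[0], so the work is sliced along dim 0 with slice_size = 16 // 16 = 1.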
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
def get_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if q.device.type == 'mps':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r = self.get_attention_mem_efficient(q, k, v)
|
||||
|
||||
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(hidden_states)
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
@ -297,9 +298,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = x.contiguous() if x.device.type == 'mps' else x
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
x += self.attn1(self.norm1(x.clone()))
|
||||
x += self.attn2(self.norm2(x.clone()), context=context)
|
||||
x += self.ff(self.norm3(x.clone()))
|
||||
return x
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@ import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import silu
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
@ -53,9 +49,15 @@ class Upsample(nn.Module):
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
cpu_m1_cond = True if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and \
|
||||
x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 else False
|
||||
if cpu_m1_cond:
|
||||
x = x.to('cpu') # send to cpu
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if cpu_m1_cond:
|
||||
x = x.to('mps') # return to mps
|
||||
return x
|
||||
|
||||
|
||||
@ -121,30 +123,25 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h1 = x
|
||||
h2 = self.norm1(h1)
|
||||
del h1
|
||||
|
||||
h3 = nonlinearity(h2)
|
||||
del h2
|
||||
|
||||
h4 = self.conv1(h3)
|
||||
del h3
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
x_size = x.size()
|
||||
if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0:
|
||||
self.to('cpu')
|
||||
x = x.to('cpu')
|
||||
else:
|
||||
self.to('mps')
|
||||
x = x.to('mps')
|
||||
h = self.norm1(x)
|
||||
h = silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(silu(temb))[:,:,None,None]
|
||||
|
||||
h5 = self.norm2(h4)
|
||||
del h4
|
||||
|
||||
h6 = nonlinearity(h5)
|
||||
del h5
|
||||
|
||||
h7 = self.dropout(h6)
|
||||
del h6
|
||||
|
||||
h8 = self.conv2(h7)
|
||||
del h7
|
||||
h = self.norm2(h)
|
||||
h = silu(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
@ -152,7 +149,7 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h8
|
||||
return x + h
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
@ -209,8 +206,7 @@ class AttnBlock(nn.Module):
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
|
||||
if device_type == 'cuda':
|
||||
if q.device.type == 'cuda':
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -263,7 +259,7 @@ class AttnBlock(nn.Module):
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
print(f" | Making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "none":
|
||||
@ -382,7 +378,7 @@ class Model(nn.Module):
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = silu(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
@ -416,7 +412,7 @@ class Model(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -513,7 +509,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -539,7 +535,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
print(" | Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
@ -599,22 +595,16 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h1 = self.conv_in(z)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h2 = self.mid.block_1(h1, temb)
|
||||
del h1
|
||||
|
||||
h3 = self.mid.attn_1(h2)
|
||||
del h2
|
||||
|
||||
h = self.mid.block_2(h3, temb)
|
||||
del h3
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# prepare for up sampling
|
||||
device_type = 'mps' if h.device.type == 'mps' else 'cuda'
|
||||
gc.collect()
|
||||
if device_type == 'cuda':
|
||||
if h.device.type == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# upsampling
|
||||
@ -622,33 +612,19 @@ class Decoder(nn.Module):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
t = h
|
||||
h = self.up[i_level].attn[i_block](t)
|
||||
del t
|
||||
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
t = h
|
||||
h = self.up[i_level].upsample(t)
|
||||
del t
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h1 = self.norm_out(h)
|
||||
del h
|
||||
|
||||
h2 = nonlinearity(h1)
|
||||
del h1
|
||||
|
||||
h = self.conv_out(h2)
|
||||
del h2
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
t = h
|
||||
h = torch.tanh(t)
|
||||
del t
|
||||
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
@ -683,7 +659,7 @@ class SimpleDecoder(nn.Module):
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
@ -731,7 +707,7 @@ class UpsampleDecoder(nn.Module):
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -907,7 +883,7 @@ class FirstStagePostProcessor(nn.Module):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
z = silu(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
|
@ -64,7 +64,11 @@ def make_ddim_timesteps(
|
||||
):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
if c < 1:
|
||||
c = 1
|
||||
|
||||
# remove 1 final step to prevent index out of bound error
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))[:-1]
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = (
|
||||
(
|
||||
@ -81,8 +85,7 @@ def make_ddim_timesteps(
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
# steps_out = ddim_timesteps + 1
|
||||
steps_out = ddim_timesteps
|
||||
steps_out = ddim_timesteps + 1
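# Worked example of the uniform schedule above (values illustrative): with
# num_ddpm_timesteps=1000 and num_ddim_timesteps=50, c = 20 and the [:-1] trim leaves
# [0, 20, ..., 960]; adding 1 then yields steps_out = [1, 21, ..., 961].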
|
||||
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
@ -252,12 +255,6 @@ def normalization(channels):
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
@ -82,7 +82,9 @@ class EmbeddingManager(nn.Module):
|
||||
get_embedding_for_clip_token,
|
||||
embedder.transformer.text_model.embeddings,
|
||||
)
|
||||
token_dim = 1280
|
||||
# per bug report #572
|
||||
#token_dim = 1280
|
||||
token_dim = 768
|
||||
else: # using LDM's BERT encoder
|
||||
self.is_clip = False
|
||||
get_token_for_string = partial(
|
||||
@ -167,9 +169,14 @@ class EmbeddingManager(nn.Module):
|
||||
placeholder_embedding.shape[0], max_step_tokens
|
||||
)
|
||||
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
else:
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token
|
||||
)
|
||||
|
||||
if placeholder_rows.nelement() == 0:
|
||||
continue
|
||||
|
@ -1,3 +1,5 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -5,7 +7,7 @@ import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
@ -454,6 +456,223 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
return_tokens_key = "return_tokens"
|
||||
|
||||
def forward(self, text: list, **kwargs):
|
||||
'''
|
||||
|
||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||
weights shall be applied.
|
||||
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
|
||||
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
|
||||
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
|
||||
'''
|
||||
if self.fragment_weights_key not in kwargs:
|
||||
# fallback to base class implementation
|
||||
return super().forward(text, **kwargs)
|
||||
|
||||
fragment_weights = kwargs[self.fragment_weights_key]
|
||||
# self.transformer doesn't like receiving "fragment_weights" as an argument
|
||||
kwargs.pop(self.fragment_weights_key)
|
||||
|
||||
should_return_tokens = False
|
||||
if self.return_tokens_key in kwargs:
|
||||
should_return_tokens = kwargs.get(self.return_tokens_key, False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop(self.return_tokens_key)
|
||||
|
||||
batch_z = None
|
||||
batch_tokens = None
|
||||
for fragments, weights in zip(text, fragment_weights):
|
||||
|
||||
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||
# applying a multiplier to the CFG scale on a per-token basis).
|
||||
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||
|
||||
# handle weights >=1
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
|
||||
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
# this is our starting point
|
||||
embeddings = base_embedding.unsqueeze(0)
|
||||
per_embedding_weights = [1.0]
|
||||
|
||||
# now handle weights <1
|
||||
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||
# removed.
|
||||
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||
for index, fragment_weight in enumerate(weights):
|
||||
if fragment_weight < 1:
|
||||
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||
weights_without_this = weights[:index] + weights[index+1:]
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
|
||||
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||
# therefore:
|
||||
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||
# so let's use tan(), because:
|
||||
# tan is 0.0 at 0,
|
||||
# 1.0 at PI/4, and
|
||||
# inf at PI/2
|
||||
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
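# illustrative values for the mapping above: weight 1.0 -> tan(0) = 0 (stay at the base
# embedding), weight 0.5 -> tan(pi/4) = 1.0 (half-way), weight 0.1 -> tan(0.45*pi) ~= 6.3
# (the fragment-less embedding dominates)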
|
||||
epsilon = 1e-9
|
||||
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||
# todo handle negative weight?
|
||||
|
||||
per_embedding_weights.append(embedding_lerp_weight)
|
||||
|
||||
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
# should have shape (B, 77, 768)
|
||||
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||
|
||||
if should_return_tokens:
|
||||
return batch_z, batch_tokens
|
||||
else:
|
||||
return batch_z
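# Hypothetical call sketch (fragment strings and weights are illustrative; `embedder` is
# an instance of this class):
#   z = embedder([["a mountain", "a man"]], fragment_weights=[[1.0, 0.5]])
# "a man" contributes at half strength, which exercises the lerp-away-from-base path above.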
|
||||
|
||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
tokens = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
if include_start_and_end_markers:
|
||||
return tokens
|
||||
else:
|
||||
return [x[1:-1] for x in tokens]
|
||||
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
if normalize:
|
||||
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||
# lerped embeddings has shape (77, 768)
|
||||
|
||||
|
||||
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
|
||||
'''
|
||||
|
||||
:param fragments:
|
||||
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
|
||||
:return:
|
||||
'''
|
||||
# empty is meaningful
|
||||
if len(fragments) == 0 and len(weights) == 0:
|
||||
fragments = ['']
|
||||
weights = [1]
|
||||
item_encodings = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=True,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
all_tokens = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for index, fragment in enumerate(item_encodings):
|
||||
weight = weights[index]
|
||||
#print("processing fragment", fragment, weight)
|
||||
fragment_tokens = item_encodings[index]
|
||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
||||
# trim bos and eos markers before appending
|
||||
all_tokens.extend(fragment_tokens[1:-1])
|
||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
||||
|
||||
if (len(all_tokens) + 2) > self.max_length:
|
||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_tokens = all_tokens[:self.max_length - 2]
|
||||
per_token_weights = per_token_weights[:self.max_length - 2]
|
||||
|
||||
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (77 = self.max_length)
|
||||
pad_length = self.max_length - 1 - len(all_tokens)
|
||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
||||
per_token_weights.insert(0, 1)
|
||||
per_token_weights.extend([1] * pad_length)
|
||||
|
||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
||||
return all_tokens_tensor, per_token_weights_tensor
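# Illustrative result (token ids are made up): for fragments ["a mountain", "a man"] with
# weights [1.0, 0.5], all_tokens becomes [bos, tok_a, tok_mountain, tok_a, tok_man, eos, eos, ...]
# padded to 77 entries, and per_token_weights becomes [1, 1.0, 1.0, 0.5, 0.5, 1, 1, ...].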
|
||||
|
||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||
:param kwargs: passed on to self.transformer()
|
||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
|
||||
if weight_delta_from_empty:
|
||||
empty_tokens = self.tokenizer([''] * z.shape[0],
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt'
|
||||
)['input_ids'].to(self.device)
|
||||
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||
|
||||
#print("using empty-delta method, first 5 rows:")
|
||||
#print(weighted_z[:5])
|
||||
|
||||
return weighted_z
|
||||
|
||||
else:
|
||||
original_mean = z.mean()
|
||||
z *= batch_weights_expanded
|
||||
after_weighting_mean = z.mean()
|
||||
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
|
||||
mean_correction_factor = original_mean/after_weighting_mean
|
||||
z *= mean_correction_factor
|
||||
return z
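# Scalar sketch of the delta-from-empty weighting above (numbers illustrative): with an
# empty-prompt feature value of 0.1, a full-prompt value of 0.7 and a token weight of 0.5,
# weighted_z = 0.1 + (0.7 - 0.1) * 0.5 = 0.4 -- the token is pulled half-way back toward
# the empty embedding instead of simply being scaled toward zero.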
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
"""
|
||||
|
25
ldm/util.py
@ -2,6 +2,7 @@ import importlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from collections import abc
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
@ -74,7 +75,7 @@ def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
@ -212,3 +213,25 @@ def parallel_data_prefetch(
|
||||
return out
|
||||
else:
|
||||
return gather_res
|
||||
|
||||
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
|
||||
|
||||
rand_val = torch.rand(res[0]+1, res[1]+1)
|
||||
|
||||
angles = 2*math.pi*rand_val
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
|
||||
|
||||
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
||||
|
||||
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
|
||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
|
||||
t = fade(grid[:shape[0], :shape[1]])
|
||||
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
|
||||
|