runs but doesn't work properly - see below for test prompt

test prompt:
"a cat sitting on a car {a dog sitting on a car}" -W 384 -H 256 -s 10 -S 12346 -A k_euler
note that the substitution of dog for cat is currently hard-coded (ksampler.py, lines 43-44)
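
For illustration (not part of this commit): a minimal sketch of how the braced edit syntax in the test prompt splits into a base prompt and an edited prompt, mirroring the regex handling added to conditioning.py below. split_edit_prompt is a made-up helper name.

import re

def split_edit_prompt(prompt):
    # hypothetical helper: everything inside {...} becomes the edited prompt,
    # the rest (with the braces removed) is the base prompt
    edited = ' '.join(re.findall(r'\{(.*?)\}', prompt))
    base = re.sub(' +', ' ', re.sub(r'\{(.*?)\}', ' ', prompt)).strip()
    return base, (edited if edited else None)

base, edited = split_edit_prompt('a cat sitting on a car {a dog sitting on a car}')
# base   -> 'a cat sitting on a car'
# edited -> 'a dog sitting on a car'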
Damian at mba 2022-10-16 20:39:47 +02:00
parent 33d6603fef
commit 8ff507b03b
8 changed files with 207 additions and 199 deletions

c_a_c.py (177 lines deleted)
View File

@@ -1,177 +0,0 @@
# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl
from difflib import SequenceMatcher
import torch
def prompt_token(prompt, index, clip_tokenizer):
tokens = clip_tokenizer(prompt,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt',
return_overflowing_tokens=True
).input_ids[0]
return clip_tokenizer.decode(tokens[index:index+1])
def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
tokens_length = clip_tokenizer.model_max_length
weights = torch.ones(tokens_length)
for i, w in weight_tuples:
if i < tokens_length and i >= 0:
weights[i] = w
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.last_attn_slice_weights = weights.to(device)
if module_name == 'CrossAttention' and 'attn1' in name:
module.last_attn_slice_weights = None
def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
tokens_length = clip_tokenizer.model_max_length
mask = torch.zeros(tokens_length)
indices_target = torch.arange(tokens_length, dtype=torch.long)
indices = torch.zeros(tokens_length, dtype=torch.long)
tokens = tokens.input_ids.numpy()[0]
tokens_edit = tokens_edit.input_ids.numpy()[0]
for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
if b0 < tokens_length:
if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0):
mask[b0:b1] = 1
indices[b0:b1] = indices_target[a0:a1]
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.last_attn_slice_mask = mask.to(device)
module.last_attn_slice_indices = indices.to(device)
if module_name == 'CrossAttention' and 'attn1' in name:
module.last_attn_slice_mask = None
module.last_attn_slice_indices = None
def init_attention_func(unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def new_attention(self, query, key, value):
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attn_slice = attention_scores.softmax(dim=-1)
# compute attention output
if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None:
new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
self.last_attn_slice_indices))
attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+ new_attn_slice * self.last_attn_slice_mask)
else:
attn_slice = self.last_attn_slice
self.use_last_attn_slice = False
if self.save_last_attn_slice:
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
hidden_states = torch.matmul(attn_slice, value)
# reshape hidden_states
return self.reshape_batch_dim_to_heads(hidden_states)
def new_sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads),
device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx],
key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None:
new_attn_slice = (torch.index_select(self.last_attn_slice,
-1, self.last_attn_slice_indices))
attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+ new_attn_slice * self.last_attn_slice_mask)
else:
attn_slice = self.last_attn_slice
self.use_last_attn_slice = False
if self.save_last_attn_slice:
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
return self.reshape_batch_dim_to_heads(hidden_states) # reshape hidden_states
for _, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention':
module.last_attn_slice = None
module.use_last_attn_slice = False
module.use_last_attn_weights = False
module.save_last_attn_slice = False
module._sliced_attention = new_sliced_attention.__get__(module, type(module))
module._attention = new_attention.__get__(module, type(module))
def use_last_tokens_attention(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.use_last_attn_slice = use
def use_last_tokens_attention_weights(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.use_last_attn_weights = use
def use_last_self_attention(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn1' in name:
module.use_last_attn_slice = use
def save_last_tokens_attention(unet, save=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.save_last_attn_slice = save
def save_last_self_attention(unet, save=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn1' in name:
module.save_last_attn_slice = save
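
For reference (not part of the commit), a standalone sketch of what SequenceMatcher.get_opcodes() hands init_attention_edit above for the test prompt; the token IDs are made up for illustration:

from difflib import SequenceMatcher

# hypothetical CLIP token IDs for '<start> a cat sitting on a car <end>'
tokens      = [49406, 320, 2368, 4919, 525, 320, 1615, 49407]
# and for '<start> a dog sitting on a car <end>'
tokens_edit = [49406, 320, 1929, 4919, 525, 320, 1615, 49407]

for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
    print(name, (a0, a1), (b0, b1))
# equal   (0, 2) (0, 2)   '<start> a' is unchanged
# replace (2, 3) (2, 3)   'cat' is swapped for 'dog' at token index 2
# equal   (3, 8) (3, 8)   'sitting on a car <end>' is unchanged

Token index 2 (the position of 'cat', counting the <start> token as index 0) is what the hard-coded token_indices_to_edit = [2] in ksampler.py below corresponds to.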

View File

@@ -1,4 +1,5 @@
import random
import traceback
import numpy as np
import torch
@@ -8,7 +9,7 @@ from PIL import Image
from torch import autocast
from tqdm.auto import tqdm
import c_a_c
import .ldm.models.diffusion.cross_attention
@torch.no_grad()
@@ -41,7 +42,7 @@ def stablediffusion(
# If seed is None, randomly select seed from 0 to 2^32-1
if seed is None: seed = random.randrange(2**32 - 1)
generator = torch.cuda.manual_seed(seed)
generator = torch.manual_seed(seed)
# Set inference timesteps to scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085,

View File

@@ -32,7 +32,7 @@ from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@@ -400,7 +400,7 @@ class Generate:
mask_image = None
try:
uc, c = get_uc_and_c(
uc, c, ec = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
log_tokens =self.log_tokenization
@@ -438,7 +438,7 @@ class Generate:
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c),
conditioning=(uc, c, ec),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
@@ -469,14 +469,14 @@ class Generate:
save_original = save_original,
image_callback = image_callback)
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
except KeyboardInterrupt:
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
except (RuntimeError, Exception) as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
toc = time.time()
print('>> Usage stats:')

View File

@@ -12,7 +12,7 @@ log_tokenization() print out colour-coded tokens and warn if trunca
import re
import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
@@ -26,7 +26,19 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
edited_words = None
edited_regex = r'\{(.*?)\}'
edited = re.findall(edited_regex, prompt)
if len(edited) > 0:
edited_words = ' '.join(edited)
edited_regex_compile = re.compile(edited_regex)
clean_prompt = edited_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
uc = model.get_learned_conditioning([unconditioned_words])
ec = None
if edited_words is not None:
ec = model.get_learned_conditioning([edited_words])
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
@@ -48,7 +60,7 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words])
return (uc, c)
return (uc, c, ec)
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""

View File

@@ -19,7 +19,7 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height'
"""
self.perlin = perlin
uc, c = conditioning
uc, c, ec = conditioning
@torch.no_grad()
def make_image(x_T):
@@ -43,6 +43,7 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
edited_conditioning = ec,
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,

View File

@@ -1,8 +1,8 @@
from enum import Enum
import torch
class CrossAttention:
class CrossAttentionControl:
class AttentionType(Enum):
SELF = 1
TOKENS = 2
@@ -15,5 +15,146 @@ class CrossAttention:
return module
@classmethod
def inject_attention_mask_capture(cls, model, callback):
pass
def setup_attention_editing(cls, model, original_tokens_length: int,
substitute_conditioning: torch.Tensor = None,
token_indices_to_edit: list = None):
# adapted from init_attention_edit
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
if substitute_conditioning is not None:
device = substitute_conditioning.device
# this is not very torch-y
mask = torch.zeros(original_tokens_length)
for i in token_indices_to_edit:
mask[i] = 1
self_attention_module.last_attn_slice_mask = None
self_attention_module.last_attn_slice_indices = None
tokens_attention_module.last_attn_slice_mask = mask.to(device)
tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device)
cls.inject_attention_functions(model)
@classmethod
def request_save_attention_maps(cls, model):
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
self_attention_module.save_last_attn_slice = True
tokens_attention_module.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model):
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
self_attention_module.use_last_attn_slice = True
tokens_attention_module.use_last_attn_slice = True
@classmethod
def inject_attention_functions(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def new_attention(self, query, key, value):
# TODO: use baddbmm for better performance
print(f"entered new_attention")
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attn_slice = attention_scores.softmax(dim=-1)
# compute attention output
if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None:
print('using masked last_attn_slice')
new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
self.last_attn_slice_indices))
attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+ new_attn_slice * self.last_attn_slice_mask)
else:
print('using unmasked last_attn_slice')
attn_slice = self.last_attn_slice
self.use_last_attn_slice = False
else:
print('not using last_attn_slice')
if self.save_last_attn_slice:
print('saving last_attn_slice')
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
else:
print('not saving last_attn_slice')
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
hidden_states = torch.matmul(attn_slice, value)
# reshape hidden_states
return hidden_states
for _, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention':
module.last_attn_slice = None
module.use_last_attn_slice = False
module.use_last_attn_weights = False
module.save_last_attn_slice = False
module.cross_attention_callback = new_attention.__get__(module, type(module))
# original code below
# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl
from difflib import SequenceMatcher
import torch
def prompt_token(prompt, index, clip_tokenizer):
tokens = clip_tokenizer(prompt,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt',
return_overflowing_tokens=True
).input_ids[0]
return clip_tokenizer.decode(tokens[index:index + 1])
def use_last_tokens_attention(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.use_last_attn_slice = use
def use_last_tokens_attention_weights(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.use_last_attn_weights = use
def use_last_self_attention(unet, use=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn1' in name:
module.use_last_attn_slice = use
def save_last_tokens_attention(unet, save=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn2' in name:
module.save_last_attn_slice = save
def save_last_self_attention(unet, save=True):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == 'CrossAttention' and 'attn1' in name:
module.save_last_attn_slice = save
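
A toy illustration (not part of the commit) of the masked blend new_attention performs when use_last_attn_slice is set: columns where the mask is 1 take the attention saved from the original-prompt pass, gathered through last_attn_slice_indices, and columns where the mask is 0 keep the freshly computed attention. Shapes are reduced to a single query row over three token columns:

import torch

saved_slice = torch.tensor([[0.90, 0.05, 0.05]])  # attention saved from the original-prompt pass
fresh_slice = torch.tensor([[0.20, 0.60, 0.20]])  # attention computed with the edited conditioning
mask        = torch.tensor([1.0, 0.0, 1.0])       # 1 = reuse the saved attention for this column
indices     = torch.tensor([0, 1, 2])             # which saved column feeds each current column

gathered = torch.index_select(saved_slice, -1, indices)
blended  = fresh_slice * (1 - mask) + gathered * mask
# blended -> tensor([[0.90, 0.60, 0.05]])

This blend is what lets the pass with the edited conditioning reuse the spatial attention layout computed for the original prompt.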

View File

@@ -13,6 +13,7 @@ from ldm.modules.diffusionmodules.util import (
noise_like,
extract_into_tensor,
)
from ldm.models.diffusion.cross_attention import CrossAttentionControl
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
@@ -29,21 +30,41 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0):
def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.edited_conditioning = edited_conditioning
if self.edited_conditioning is not None:
initial_tokens_count = 77 # '<start> a cat sitting on a car <end>'
token_indices_to_edit = [2] # 'cat'
CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit)
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
module = self.get_attention_module(AttentionLayer.TOKENS)
print('generating new unconditioned latents')
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
# process x using the original prompt, saving the attention maps if required
if self.edited_conditioning is not None:
# this is automatically toggled off after the model forward()
CrossAttentionControl.request_save_attention_maps(self.inner_model)
print('generating new conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=cond)
if self.edited_conditioning is not None:
# process x again, using the saved attention maps but the new conditioning
# this is automatically toggled off after the model forward()
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
print('generating edited conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -52,7 +73,8 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
delta = (conditioned_latents - unconditioned_latents)
return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
class KSampler(Sampler):
@@ -169,6 +191,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -200,7 +223,7 @@ class KSampler(Sampler):
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
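
Condensed for readability (a sketch of the new forward() flow above, with the threshold/warmup handling omitted; cfg_step is an illustrative name, the other argument names follow the diff):

from ldm.models.diffusion.cross_attention import CrossAttentionControl

def cfg_step(inner_model, x, sigma, uncond, cond, cond_scale, edited_conditioning=None):
    # 1. unconditioned pass
    unconditioned_latents = inner_model(x, sigma, cond=uncond)

    # 2. conditioned pass with the original prompt, saving attention maps if editing
    if edited_conditioning is not None:
        CrossAttentionControl.request_save_attention_maps(inner_model)
    conditioned_latents = inner_model(x, sigma, cond=cond)

    # 3. optional second conditioned pass: edited prompt, reusing the saved attention maps
    if edited_conditioning is not None:
        CrossAttentionControl.request_apply_saved_attention_maps(inner_model)
        conditioned_latents = inner_model(x, sigma, cond=edited_conditioning)

    # 4. standard classifier-free guidance combination (cfg_apply_threshold omitted here)
    delta = conditioned_latents - unconditioned_latents
    return unconditioned_latents + delta * cond_scale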

View File

@@ -170,6 +170,8 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.cross_attention_callback = None
def einsum_op_compvis(self, q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
@@ -244,8 +246,13 @@ class CrossAttention(nn.Module):
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
if self.cross_attention_callback is not None:
r = self.cross_attention_callback(q, k, v)
else:
r = self.einsum_op(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module):
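
The cross_attention_callback hook added above is what CrossAttentionControl.inject_attention_functions binds its replacement attention function to. A minimal hedged sketch of attaching a custom callback; logging_callback is a made-up example and unet stands for whichever module tree holds the CrossAttention blocks:

def logging_callback(self, q, k, v):
    # same contract as einsum_op: per-head q, k, v in, attention output of shape (b*h, n, d) out
    print('CrossAttention called:', q.shape, k.shape, v.shape)
    return self.einsum_op(q, k, v)

for _, module in unet.named_modules():
    if type(module).__name__ == 'CrossAttention':
        # bind as a method so the callback can reach self.einsum_op, self.scale, etc.
        module.cross_attention_callback = logging_callback.__get__(module, type(module))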