runs but doesn't work properly yet - see below for the test prompt
test prompt: "a cat sitting on a car {a dog sitting on a car}" -W 384 -H 256 -s 10 -S 12346 -A k_euler. Note that the substitution of dog for cat is currently hard-coded (ksampler.py, lines 43-44).
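
For reference, a minimal sketch of why the hard-coded edit targets token index 2 for this test prompt. This assumes the standard HuggingFace CLIP tokenizer used by Stable Diffusion ('openai/clip-vit-large-patch14'); it is illustrative only and not part of this commit:

# illustrative only, not part of this commit
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
ids = tokenizer('a cat sitting on a car',
                padding='max_length',
                max_length=tokenizer.model_max_length,  # 77, matching initial_tokens_count in ksampler.py
                truncation=True).input_ids

# position 0 is <|startoftext|>, so 'cat' lands at index 2,
# which is why ksampler.py hard-codes token_indices_to_edit = [2]
for i in range(8):
    print(i, repr(tokenizer.decode([ids[i]])))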
This commit is contained in:
parent 33d6603fef
commit 8ff507b03b

c_a_c.py (177 lines changed; the file is deleted in this commit, see the first hunk below)
c_a_c.py
@@ -1,177 +0,0 @@
-# Functions supporting Cross-Attention Control
-# Copied from https://github.com/bloc97/CrossAttentionControl
-
-from difflib import SequenceMatcher
-
-import torch
-
-
-def prompt_token(prompt, index, clip_tokenizer):
-    tokens = clip_tokenizer(prompt,
-                            padding='max_length',
-                            max_length=clip_tokenizer.model_max_length,
-                            truncation=True,
-                            return_tensors='pt',
-                            return_overflowing_tokens=True
-                            ).input_ids[0]
-    return clip_tokenizer.decode(tokens[index:index+1])
-
-
-def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
-    tokens_length = clip_tokenizer.model_max_length
-    weights = torch.ones(tokens_length)
-
-    for i, w in weight_tuples:
-        if i < tokens_length and i >= 0:
-            weights[i] = w
-
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.last_attn_slice_weights = weights.to(device)
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.last_attn_slice_weights = None
-
-
-def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
-    tokens_length = clip_tokenizer.model_max_length
-    mask = torch.zeros(tokens_length)
-    indices_target = torch.arange(tokens_length, dtype=torch.long)
-    indices = torch.zeros(tokens_length, dtype=torch.long)
-
-    tokens = tokens.input_ids.numpy()[0]
-    tokens_edit = tokens_edit.input_ids.numpy()[0]
-
-    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
-        if b0 < tokens_length:
-            if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0):
-                mask[b0:b1] = 1
-                indices[b0:b1] = indices_target[a0:a1]
-
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.last_attn_slice_mask = mask.to(device)
-            module.last_attn_slice_indices = indices.to(device)
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.last_attn_slice_mask = None
-            module.last_attn_slice_indices = None
-
-
-def init_attention_func(unet):
-    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
-    def new_attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
-        attn_slice = attention_scores.softmax(dim=-1)
-        # compute attention output
-
-        if self.use_last_attn_slice:
-            if self.last_attn_slice_mask is not None:
-                new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
-                                                     self.last_attn_slice_indices))
-                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
-                              + new_attn_slice * self.last_attn_slice_mask)
-            else:
-                attn_slice = self.last_attn_slice
-
-            self.use_last_attn_slice = False
-
-        if self.save_last_attn_slice:
-            self.last_attn_slice = attn_slice
-            self.save_last_attn_slice = False
-
-        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-            attn_slice = attn_slice * self.last_attn_slice_weights
-            self.use_last_attn_weights = False
-
-        hidden_states = torch.matmul(attn_slice, value)
-        # reshape hidden_states
-        return self.reshape_batch_dim_to_heads(hidden_states)
-
-    def new_sliced_attention(self, query, key, value, sequence_length, dim):
-
-        batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads),
-            device=query.device, dtype=query.dtype
-        )
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-        for i in range(hidden_states.shape[0] // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx],
-                             key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
-            attn_slice = attn_slice.softmax(dim=-1)
-
-            if self.use_last_attn_slice:
-                if self.last_attn_slice_mask is not None:
-                    new_attn_slice = (torch.index_select(self.last_attn_slice,
-                                                         -1, self.last_attn_slice_indices))
-                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
-                                  + new_attn_slice * self.last_attn_slice_mask)
-                else:
-                    attn_slice = self.last_attn_slice
-
-                self.use_last_attn_slice = False
-
-            if self.save_last_attn_slice:
-                self.last_attn_slice = attn_slice
-                self.save_last_attn_slice = False
-
-            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-                attn_slice = attn_slice * self.last_attn_slice_weights
-                self.use_last_attn_weights = False
-
-            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
-
-        return self.reshape_batch_dim_to_heads(hidden_states)  # reshape hidden_states
-
-    for _, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention':
-            module.last_attn_slice = None
-            module.use_last_attn_slice = False
-            module.use_last_attn_weights = False
-            module.save_last_attn_slice = False
-            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
-            module._attention = new_attention.__get__(module, type(module))
-
-
-def use_last_tokens_attention(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.use_last_attn_slice = use
-
-
-def use_last_tokens_attention_weights(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.use_last_attn_weights = use
-
-
-def use_last_self_attention(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.use_last_attn_slice = use
-
-
-def save_last_tokens_attention(unet, save=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.save_last_attn_slice = save
-
-
-def save_last_self_attention(unet, save=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.save_last_attn_slice = save
@@ -1,4 +1,5 @@
 import random
+import traceback
 
 import numpy as np
 import torch
@@ -8,7 +9,7 @@ from PIL import Image
 from torch import autocast
 from tqdm.auto import tqdm
 
-import c_a_c
+import .ldm.models.diffusion.cross_attention
 
 
 @torch.no_grad()
@@ -41,7 +42,7 @@ def stablediffusion(
 
     # If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
-    generator = torch.cuda.manual_seed(seed)
+    generator = torch.manual_seed(seed)
 
     # Set inference timesteps to scheduler
     scheduler = LMSDiscreteScheduler(beta_start=0.00085,
@@ -32,7 +32,7 @@ from ldm.invoke.pngwriter import PngWriter
 from ldm.invoke.args import metadata_from_png
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
-from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.model_cache import ModelCache
 from ldm.invoke.seamless import configure_model_padding
 from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@@ -400,7 +400,7 @@ class Generate:
         mask_image = None
 
         try:
-            uc, c = get_uc_and_c(
+            uc, c, ec = get_uc_and_c_and_ec(
                 prompt, model =self.model,
                 skip_normalize=skip_normalize,
                 log_tokens =self.log_tokenization
@@ -438,7 +438,7 @@ class Generate:
                 sampler=self.sampler,
                 steps=steps,
                 cfg_scale=cfg_scale,
-                conditioning=(uc, c),
+                conditioning=(uc, c, ec),
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,  # called after the final image is generated
                 step_callback=step_callback,  # called after each intermediate image is generated
@@ -469,14 +469,14 @@ class Generate:
                     save_original = save_original,
                     image_callback = image_callback)
 
-        except RuntimeError as e:
-            print(traceback.format_exc(), file=sys.stderr)
-            print('>> Could not generate image.')
         except KeyboardInterrupt:
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
+        except (RuntimeError, Exception) as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')
 
         toc = time.time()
         print('>> Usage stats:')
@@ -12,7 +12,7 @@ log_tokenization() print out colour-coded tokens and warn if trunca
 import re
 import torch
 
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
@@ -26,7 +26,19 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
     clean_prompt = unconditional_regex_compile.sub(' ', prompt)
     prompt = re.sub(' +', ' ', clean_prompt)
 
+    edited_words = None
+    edited_regex = r'\{(.*?)\}'
+    edited = re.findall(edited_regex, prompt)
+    if len(edited) > 0:
+        edited_words = ' '.join(edited)
+        edited_regex_compile = re.compile(edited_regex)
+        clean_prompt = edited_regex_compile.sub(' ', prompt)
+        prompt = re.sub(' +', ' ', clean_prompt)
+
     uc = model.get_learned_conditioning([unconditioned_words])
+    ec = None
+    if edited_words is not None:
+        ec = model.get_learned_conditioning([edited_words])
 
     # get weighted sub-prompts
     weighted_subprompts = split_weighted_subprompts(
@@ -48,7 +60,7 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
         log_tokenization(prompt, model, log_tokens, 1)
         c = model.get_learned_conditioning([prompt])
         uc = model.get_learned_conditioning([unconditioned_words])
-    return (uc, c)
+    return (uc, c, ec)
 
 def split_weighted_subprompts(text, skip_normalize=False)->list:
     """
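
To make the new prompt syntax concrete, here is a standalone sketch of what the regexes added to get_uc_and_c_and_ec do with the test prompt from the commit message. Names mirror the diff above, but this snippet is not the shipped function, and the .strip() is only for tidy printing:

import re

prompt = 'a cat sitting on a car {a dog sitting on a car}'

edited_regex = r'\{(.*?)\}'
edited = re.findall(edited_regex, prompt)            # ['a dog sitting on a car']
edited_words = ' '.join(edited) if edited else None

clean_prompt = re.compile(edited_regex).sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt).strip()

print(prompt)        # 'a cat sitting on a car'  -> encoded as c
print(edited_words)  # 'a dog sitting on a car'  -> encoded as ec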
@@ -19,7 +19,7 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
-        uc, c = conditioning
+        uc, c, ec = conditioning
 
         @torch.no_grad()
         def make_image(x_T):
@@ -43,6 +43,7 @@ class Txt2Img(Generator):
                 verbose = False,
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
+                edited_conditioning = ec,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,
@@ -1,8 +1,8 @@
 from enum import Enum
+import torch
 
 
-class CrossAttention:
+class CrossAttentionControl:
 
     class AttentionType(Enum):
         SELF = 1
         TOKENS = 2
@@ -15,5 +15,146 @@ class CrossAttention:
         return module
 
     @classmethod
-    def inject_attention_mask_capture(cls, model, callback):
-        pass
+    def setup_attention_editing(cls, model, original_tokens_length: int,
+                                substitute_conditioning: torch.Tensor = None,
+                                token_indices_to_edit: list = None):
+
+        # adapted from init_attention_edit
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+
+        if substitute_conditioning is not None:
+
+            device = substitute_conditioning.device
+
+            # this is not very torch-y
+            mask = torch.zeros(original_tokens_length)
+            for i in token_indices_to_edit:
+                mask[i] = 1
+
+            self_attention_module.last_attn_slice_mask = None
+            self_attention_module.last_attn_slice_indices = None
+            tokens_attention_module.last_attn_slice_mask = mask.to(device)
+            tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device)
+
+        cls.inject_attention_functions(model)
+
+    @classmethod
+    def request_save_attention_maps(cls, model):
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+        self_attention_module.save_last_attn_slice = True
+        tokens_attention_module.save_last_attn_slice = True
+
+    @classmethod
+    def request_apply_saved_attention_maps(cls, model):
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+        self_attention_module.use_last_attn_slice = True
+        tokens_attention_module.use_last_attn_slice = True
+
+    @classmethod
+    def inject_attention_functions(cls, unet):
+        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+        def new_attention(self, query, key, value):
+            # TODO: use baddbmm for better performance
+            print(f"entered new_attention")
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+            attn_slice = attention_scores.softmax(dim=-1)
+            # compute attention output
+
+            if self.use_last_attn_slice:
+                if self.last_attn_slice_mask is not None:
+                    print('using masked last_attn_slice')
+
+                    new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
+                                                         self.last_attn_slice_indices))
+                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+                                  + new_attn_slice * self.last_attn_slice_mask)
+                else:
+                    print('using unmasked last_attn_slice')
+                    attn_slice = self.last_attn_slice
+
+                self.use_last_attn_slice = False
+            else:
+                print('not using last_attn_slice')
+
+            if self.save_last_attn_slice:
+                print('saving last_attn_slice')
+                self.last_attn_slice = attn_slice
+                self.save_last_attn_slice = False
+            else:
+                print('not saving last_attn_slice')
+
+            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
+                attn_slice = attn_slice * self.last_attn_slice_weights
+                self.use_last_attn_weights = False
+
+            hidden_states = torch.matmul(attn_slice, value)
+            # reshape hidden_states
+            return hidden_states
+
+        for _, module in unet.named_modules():
+            module_name = type(module).__name__
+            if module_name == 'CrossAttention':
+                module.last_attn_slice = None
+                module.use_last_attn_slice = False
+                module.use_last_attn_weights = False
+                module.save_last_attn_slice = False
+                module.cross_attention_callback = new_attention.__get__(module, type(module))
+
+
+# original code below
+
+# Functions supporting Cross-Attention Control
+# Copied from https://github.com/bloc97/CrossAttentionControl
+
+from difflib import SequenceMatcher
+
+import torch
+
+
+def prompt_token(prompt, index, clip_tokenizer):
+    tokens = clip_tokenizer(prompt,
+                            padding='max_length',
+                            max_length=clip_tokenizer.model_max_length,
+                            truncation=True,
+                            return_tensors='pt',
+                            return_overflowing_tokens=True
+                            ).input_ids[0]
+    return clip_tokenizer.decode(tokens[index:index + 1])
+
+
+def use_last_tokens_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_slice = use
+
+
+def use_last_tokens_attention_weights(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_weights = use
+
+
+def use_last_self_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.use_last_attn_slice = use
+
+
+def save_last_tokens_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.save_last_attn_slice = save
+
+
+def save_last_self_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.save_last_attn_slice = save
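
A condensed sketch of how the ksampler changes below are expected to drive this class. Here model, x, sigma, cond and edited_cond are placeholders for the inner model, the latents, the sigma value and the two conditioning tensors; they are not objects defined in this diff:

# one-time setup: 77 tokens, edit token index 2 ('cat'), hard-coded for the test prompt
CrossAttentionControl.setup_attention_editing(model, 77, edited_cond, token_indices_to_edit=[2])

# pass 1: run the original conditioning and capture the attention maps
CrossAttentionControl.request_save_attention_maps(model)   # cleared again inside new_attention
_ = model(x, sigma, cond=cond)

# pass 2: run the edited conditioning, re-using the saved maps at the masked token positions
CrossAttentionControl.request_apply_saved_attention_maps(model)
edited_latents = model(x, sigma, cond=edited_cond)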
@@ -13,6 +13,7 @@ from ldm.modules.diffusionmodules.util import (
     noise_like,
     extract_into_tensor,
 )
+from ldm.models.diffusion.cross_attention import CrossAttentionControl
 
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
@@ -29,21 +30,41 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
 
 
 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
 
+        self.edited_conditioning = edited_conditioning
+
+        if self.edited_conditioning is not None:
+            initial_tokens_count = 77 # '<start> a cat sitting on a car <end>'
+            token_indices_to_edit = [2] # 'cat'
+            CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
 
-        module = self.get_attention_module(AttentionLayer.TOKENS)
+        print('generating new unconditioned latents')
+        unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
+
+        # process x using the original prompt, saving the attention maps if required
+        if self.edited_conditioning is not None:
+            # this is automatically toggled off after the model forward()
+            CrossAttentionControl.request_save_attention_maps(self.inner_model)
+        print('generating new conditioned latents')
+        conditioned_latents = self.inner_model(x, sigma, cond=cond)
+
+        if self.edited_conditioning is not None:
+            # process x again, using the saved attention maps but the new conditioning
+            # this is automatically toggled off after the model forward()
+            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
+            print('generating edited conditioned latents')
+            conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
+
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -52,7 +73,8 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
+        delta = (conditioned_latents - unconditioned_latents)
+        return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
 
 
 class KSampler(Sampler):
@@ -169,6 +191,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
+        edited_conditioning=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -200,7 +223,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
@@ -170,6 +170,8 @@ class CrossAttention(nn.Module):
 
         self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 
+        self.cross_attention_callback = None
+
     def einsum_op_compvis(self, q, k, v):
         s = einsum('b i d, b j d -> b i j', q, k)
         s = s.softmax(dim=-1, dtype=s.dtype)
@@ -244,8 +246,13 @@ class CrossAttention(nn.Module):
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-        r = self.einsum_op(q, k, v)
-        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+        if self.cross_attention_callback is not None:
+            r = self.cross_attention_callback(q, k, v)
+        else:
+            r = self.einsum_op(q, k, v)
+        hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(hidden_states)
 
 
 class BasicTransformerBlock(nn.Module):
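
For completeness, a toy illustration of the new hook on CrossAttention: any callable taking (q, k, v) in the rearranged '(b h) n d' layout can replace the stock attention computation on a given module. Here module stands for some CrossAttention instance obtained elsewhere, and the logging callback is hypothetical, not part of this commit:

def logging_attention(q, k, v):
    # inspect the per-head shapes, then fall back to the unmodified attention path
    print('q', q.shape, 'k', k.shape, 'v', v.shape)
    return module.einsum_op(q, k, v)

module.cross_attention_callback = logging_attention   # hook in
# module.cross_attention_callback = None               # restore default behaviour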