cross-attention working with placeholder {} syntax

Author: Damian at mba
Date: 2022-10-17 21:15:03 +02:00
parent 8ff507b03b
commit 1fc1f8bf05
8 changed files with 534 additions and 237 deletions


@@ -30,7 +30,7 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
+    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
@@ -39,24 +39,23 @@ class CFGDenoiser(nn.Module):
         self.edited_conditioning = edited_conditioning

-        if self.edited_conditioning is not None:
-            initial_tokens_count = 77 # '<start> a cat sitting on a car <end>'
-            token_indices_to_edit = [2] # 'cat'
-            CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit)
+        if edited_conditioning is not None:
+            # <start> a cat sitting on a car <end>
+            CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
         else:
             # pass through the attention func but don't act on it
             CrossAttentionControl.setup_attention_editing(self.inner_model)

     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])

-        print('generating new unconditioned latents')
+        print('generating unconditioned latents')
         unconditioned_latents = self.inner_model(x, sigma, cond=uncond)

         # process x using the original prompt, saving the attention maps if required
         if self.edited_conditioning is not None:
             # this is automatically toggled off after the model forward()
             CrossAttentionControl.request_save_attention_maps(self.inner_model)

-        print('generating new conditioned latents')
+        print('generating conditioned latents')
         conditioned_latents = self.inner_model(x, sigma, cond=cond)

         if self.edited_conditioning is not None:
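
The hard-coded initial_tokens_count and token_indices_to_edit from the previous revision are replaced here by an edit_opcodes argument computed outside the denoiser. That computation is not part of this hunk; below is a minimal sketch of one way such opcodes could be derived, assuming they follow difflib's (tag, i1, i2, j1, j2) convention and using a hypothetical build_edit_opcodes helper. The repository's own tokenizer and plumbing may differ.

from difflib import SequenceMatcher

def build_edit_opcodes(original_tokens, edited_tokens):
    # Map spans of the original token sequence onto spans of the edited one,
    # e.g. ('equal', 0, 2, 0, 2) or ('replace', 2, 3, 2, 4).
    return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

# '<start> a cat sitting on a car <end>' edited to swap 'cat' for 'smiling dog'
original = ['<start>', 'a', 'cat', 'sitting', 'on', 'a', 'car', '<end>']
edited   = ['<start>', 'a', 'smiling', 'dog', 'sitting', 'on', 'a', 'car', '<end>']
print(build_edit_opcodes(original, edited))
# [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 4), ('equal', 3, 8, 4, 9)]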
@@ -192,6 +191,7 @@ class KSampler(Sampler):
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         edited_conditioning=None,
+        edit_token_index_map=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
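
KSampler.sample gains a matching edit_token_index_map keyword, which the following hunk forwards into CFGDenoiser as edit_opcodes. For orientation, here is a hypothetical call-site sketch that is not part of this commit: only the keyword names are taken from the signature above, while sampler, cond, uncond, edited_cond and opcodes are assumed to already exist.

def sample_with_edit(sampler, steps, shape, cond, uncond, edited_cond, opcodes):
    # Thread both the edited conditioning and its opcode map through the sampler.
    return sampler.sample(
        S=steps,
        batch_size=1,
        shape=shape,                      # latent shape, e.g. (4, 64, 64)
        conditioning=cond,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uncond,
        edited_conditioning=edited_cond,
        edit_token_index_map=opcodes,
    )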
@@ -223,7 +223,8 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
+                                     edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,