Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
cross-attention working with placeholder {} syntax
@@ -30,7 +30,7 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
 
 
 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
+    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
@@ -39,24 +39,23 @@ class CFGDenoiser(nn.Module):
 
         self.edited_conditioning = edited_conditioning
 
-        if self.edited_conditioning is not None:
-            initial_tokens_count = 77 # '<start> a cat sitting on a car <end>'
-            token_indices_to_edit = [2] # 'cat'
-            CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit)
+        if edited_conditioning is not None:
+            # <start> a cat sitting on a car <end>
+            CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
         else:
             # pass through the attention func but don't act on it
             CrossAttentionControl.setup_attention_editing(self.inner_model)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
 
-        print('generating new unconditioned latents')
+        print('generating unconditioned latents')
         unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
 
         # process x using the original prompt, saving the attention maps if required
         if self.edited_conditioning is not None:
             # this is automatically toggled off after the model forward()
             CrossAttentionControl.request_save_attention_maps(self.inner_model)
-        print('generating new conditioned latents')
+        print('generating conditioned latents')
         conditioned_latents = self.inner_model(x, sigma, cond=cond)
 
         if self.edited_conditioning is not None:
@@ -192,6 +191,7 @@ class KSampler(Sampler):
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        edited_conditioning=None,
+       edit_token_index_map=None,
        threshold = 0,
        perlin = 0,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -223,7 +223,8 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
+                                     edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
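Note: the diff replaces the hard-coded token_indices_to_edit with an edit_opcodes / edit_token_index_map argument that KSampler.sample() now threads into CFGDenoiser. The hunks shown do not include how that mapping is built; the sketch below is only an illustration, assuming difflib.SequenceMatcher-style (tag, i1, i2, j1, j2) opcodes computed over the token ids of the original and edited prompts. The build_edit_opcodes helper and the token ids are hypothetical, not code from this commit.

# Illustrative only -- assumes edit opcodes are difflib-style tuples over two token sequences.
from difflib import SequenceMatcher

def build_edit_opcodes(original_token_ids, edited_token_ids):
    # describe how the edited token sequence differs from the original one
    return SequenceMatcher(None, original_token_ids, edited_token_ids).get_opcodes()

# hypothetical token ids for '<start> a cat sitting on a car <end>' and a 'cat' -> 'dog' edit
original = [49406, 320, 2368, 4919, 525, 320, 1615, 49407]
edited   = [49406, 320, 1929, 4919, 525, 320, 1615, 49407]
print(build_edit_opcodes(original, edited))
# expected: [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 8, 3, 8)]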
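The forward() shown above now runs the inner model twice, once with uncond and once with cond, so that CrossAttentionControl can save attention maps during the conditioned pass. The guidance combination itself falls outside the shown hunks; below is a minimal sketch of the standard classifier-free guidance step only (the real class additionally applies its threshold/warmup handling via cfg_apply_threshold).

# Minimal sketch of the standard CFG combination; not the file's actual code.
import torch

def combine_cfg(unconditioned_latents: torch.Tensor,
                conditioned_latents: torch.Tensor,
                cond_scale: float) -> torch.Tensor:
    # start from the unconditioned prediction and move toward the conditioned one,
    # scaled by the guidance strength
    return unconditioned_latents + cond_scale * (conditioned_latents - unconditioned_latents)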