fix off-by-one bug in cross-attention-control (#1774)

Prompt token sequences begin with a "beginning-of-sequence" marker <bos> and end with a repeated "end-of-sequence" marker <eos>, giving a default sequence of <bos> + 75 prompt tokens + <eos>. The .swap() code was failing to take the column for <bos> at index 0 into account. The changes here do that, and also add explicit handling for a single <eos> (which may be redundant, but is included for completeness).
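
For illustration, here is a minimal sketch of the sequence layout described above (the helper and the prompt token ids are made up; the <bos>/<eos> ids are the CLIP tokenizer's):

BOS_ID, EOS_ID, MAX_LENGTH = 49406, 49407, 77

def layout(prompt_token_ids):
    # index 0 is <bos>, prompt tokens start at index 1,
    # and the rest of the 77-slot window is filled with repeated <eos>
    padding = MAX_LENGTH - len(prompt_token_ids) - 1
    return [BOS_ID] + prompt_token_ids + [EOS_ID] * padding

sequence = layout([320, 2368, 525, 320, 1055])  # illustrative ids only
assert sequence[0] == BOS_ID and len(sequence) == MAX_LENGTH
# The first prompt token lives at sequence[1], not sequence[0], so opcode spans
# that ignore the <bos> column end up shifted left by one.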

Based on my understanding, and some assumptions about how this all works, the reason .swap() nevertheless seemed to do roughly the right thing is that, over multiple steps, the conditioning process in Stable Diffusion operates as a feedback loop: a change to token n-1 has flow-on effects on how the [1x4x64x64] latent tensor is modified by all the tokens after it, and, as the next step is processed, by all the tokens before it as well. Intuitively, a token's conditioning effects "echo" throughout the whole length of the prompt. So even though the token at n-1 was being edited when the user actually wanted to edit the token at n, the edit still had a non-negligible effect, in roughly the right direction, often enough that it seemed to be working properly.
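
As a toy illustration of that feedback-loop intuition (deliberately simplified stand-in code, not anything from the repository): perturbing a single conditioning row shifts the latent on the first step, and every later step then attends over an already-shifted latent, so the change spreads well beyond the edited column.

import torch

def toy_step(latent, cond):
    # stand-in for one sampler step: queries from the latent, keys/values from
    # the 77 token-conditioning rows (toy dimensions, not the real UNet)
    q = latent.reshape(1, -1, cond.shape[1])               # [1, 4096, d]
    attn = torch.softmax(q @ cond.T.unsqueeze(0), dim=-1)  # [1, 4096, 77]
    return latent + 0.1 * (attn @ cond.unsqueeze(0)).reshape(latent.shape)

cond = torch.randn(77, 4)
edited = cond.clone()
edited[40] += 1.0                      # change a single token's conditioning row

a = torch.zeros(1, 4, 64, 64)
b = torch.zeros(1, 4, 64, 64)
for _ in range(10):                    # multiple steps: each output feeds the next
    a, b = toy_step(a, cond), toy_step(b, edited)
print((a - b).abs().mean())            # nonzero: the single-row edit "echoes" everywhere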
Damian Stewart 2022-12-04 11:41:03 +01:00 committed by Damian Stewart
parent f3570d8344
commit c6f31e5f36


@@ -77,8 +77,13 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
     original_token_count = 0
     edited_token_count = 0
-    edit_opcodes = []
     edit_options = []
+    edit_opcodes = []
+    # beginning of sequence
+    edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
     for fragment in flattened_prompt.children:
         if type(fragment) is CrossAttentionControlSubstitute:
             original_prompt.append(fragment.original)
@@ -105,6 +110,12 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             edit_options.append(None)
             original_token_count += count
             edited_token_count += count
+    # end of sequence
+    edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
     original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
                                                                                             original_prompt,
                                                                                             log_tokens=log_tokens,
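
For reference, the bookkeeping this patch introduces can be sketched in isolation (illustrative only, not the actual get_uc_and_c_and_ec code; CrossAttentionControlSubstitute fragments get separate, non-'equal' handling that is omitted here):

def build_equal_opcodes(fragment_token_counts):
    # one 'equal' opcode for the <bos> column at index 0, one span per
    # unchanged fragment, and a trailing 'equal' opcode for the final <eos>
    opcodes, orig, edited = [], 0, 0
    opcodes.append(('equal', orig, orig + 1, edited, edited + 1))  # <bos>
    orig += 1; edited += 1
    for count in fragment_token_counts:
        opcodes.append(('equal', orig, orig + count, edited, edited + count))
        orig += count; edited += count
    opcodes.append(('equal', orig, orig + 1, edited, edited + 1))  # <eos>
    return opcodes

# e.g. a prompt whose fragments tokenize to 2 and 3 tokens:
# [('equal', 0, 1, 0, 1), ('equal', 1, 3, 1, 3), ('equal', 3, 6, 3, 6), ('equal', 6, 7, 6, 7)]
print(build_equal_opcodes([2, 3]))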