From c6f31e5f3603a552d0d29a845ed2dec32e57aee6 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 4 Dec 2022 11:41:03 +0100 Subject: [PATCH] fix off-by-one bug in cross-attention-control (#1774) prompt token sequences begin with a "beginning-of-sequence" marker and end with a repeated "end-of-sequence" marker - to make a default prompt length of + 75 prompt tokens + . the .swap() code was failing to take the column for at index 0 into account. the changes here do that, and also add extra handling for a single (which may be redundant but which is included for completeness). based on my understanding and some assumptions about how this all works, the reason .swap() nevertheless seemed to do the right thing, to some extent, is because over multiple steps the conditioning process in Stable Diffusion operates as a feedback loop. a change to token n-1 has flow-on effects to how the [1x4x64x64] latent tensor is modified by all the tokens after it, - and as the next step is processed, all the tokens before it as well. intuitively, a token's conditioning effects "echo" throughout the whole length of the prompt. so even though the token at n-1 was being edited when what the user actually wanted was to edit the token at n, it nevertheless still had some non-negligible effect, in roughly the right direction, often enough that it seemed like it was working properly. --- ldm/invoke/conditioning.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 54092578a1..328167d783 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -77,8 +77,13 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed original_token_count = 0 edited_token_count = 0 - edit_opcodes = [] edit_options = [] + edit_opcodes = [] + # beginning of sequence + edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1)) + edit_options.append(None) + original_token_count += 1 + edited_token_count += 1 for fragment in flattened_prompt.children: if type(fragment) is CrossAttentionControlSubstitute: original_prompt.append(fragment.original) @@ -105,6 +110,12 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n edit_options.append(None) original_token_count += count edited_token_count += count + # end of sequence + edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1)) + edit_options.append(None) + original_token_count += 1 + edited_token_count += 1 + original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens,