From 64339af2dcda21aa79843f7f1bf5d1ede9d52406 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 14 Dec 2022 22:03:21 +0100 Subject: [PATCH] restrict to 75 tokens and correctly handle blends --- ldm/invoke/conditioning.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index aba329ccde..cf6e84ec60 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -36,13 +36,16 @@ Union[FlattenedPrompt, Blend], FlattenedPrompt): return prompt, negative_prompt -def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]: +def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]: text_fragments = [x.text if type(x) is Fragment else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x)) for x in parsed_prompt.children] text = " ".join(text_fragments) tokens = model.cond_stage_model.tokenizer.tokenize(text) + if truncate_if_too_long: + max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75 + tokens = tokens[0:max_tokens_length] return tokens @@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.") cac_args = None - eos_token_index = 1 - if type(parsed_prompt) is not Blend: + if type(parsed_prompt) is Blend: + blend: Blend = parsed_prompt + all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts] + longest_token_sequence = max(all_token_sequences, key=lambda t: len(t)) + eos_token_index = len(longest_token_sequence)+1 + else: tokens = get_tokens_for_prompt(model, parsed_prompt) eos_token_index = len(tokens)+1 return (