restrict to 75 tokens and correctly handle blends

This commit is contained in:
Damian Stewart 2022-12-14 22:03:21 +01:00 committed by Lincoln Stein
parent 5d20f47993
commit 64339af2dc

View File

@ -36,13 +36,16 @@ Union[FlattenedPrompt, Blend], FlattenedPrompt):
return prompt, negative_prompt return prompt, negative_prompt
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]: def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
text_fragments = [x.text if type(x) is Fragment else text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x)) str(x))
for x in parsed_prompt.children] for x in parsed_prompt.children]
text = " ".join(text_fragments) text = " ".join(text_fragments)
tokens = model.cond_stage_model.tokenizer.tokenize(text) tokens = model.cond_stage_model.tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens return tokens
@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.") ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None cac_args = None
eos_token_index = 1 if type(parsed_prompt) is Blend:
if type(parsed_prompt) is not Blend: blend: Blend = parsed_prompt
all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
eos_token_index = len(longest_token_sequence)+1
else:
tokens = get_tokens_for_prompt(model, parsed_prompt) tokens = get_tokens_for_prompt(model, parsed_prompt)
eos_token_index = len(tokens)+1 eos_token_index = len(tokens)+1
return ( return (