Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
restrict to 75 tokens and correctly handle blends
parent 5d20f47993
commit 64339af2dc
@@ -36,13 +36,16 @@ Union[FlattenedPrompt, Blend], FlattenedPrompt):
     return prompt, negative_prompt


-def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
+def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
     text_fragments = [x.text if type(x) is Fragment else
                       (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                        str(x))
                       for x in parsed_prompt.children]
     text = " ".join(text_fragments)
     tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    if truncate_if_too_long:
+        max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75
+        tokens = tokens[0:max_tokens_length]
     return tokens


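The new truncate_if_too_long flag caps prompts at the tokenizer's usable length: CLIP's context window is 77 positions, two of which are reserved for the begin- and end-of-sequence markers, leaving the 75 tokens the commit title refers to. Below is a minimal standalone sketch of the same truncation, assuming a Hugging Face CLIPTokenizer in place of model.cond_stage_model.tokenizer (an assumption; the diff itself goes through the model's wrapped tokenizer).

# Sketch only: a standalone CLIPTokenizer stands in for the
# model.cond_stage_model wrapper used in the diff (assumption).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def get_truncated_tokens(text: str, truncate_if_too_long: bool = True) -> list:
    tokens = tokenizer.tokenize(text)  # token strings, BOS/EOS not yet added
    if truncate_if_too_long:
        # model_max_length is 77 for CLIP; keep 2 slots free for BOS and EOS.
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[:max_tokens_length]
    return tokens

# e.g. len(get_truncated_tokens("a castle " * 60)) <= 75  # always capped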
@@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
             ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
         cac_args = None

-    eos_token_index = 1
-    if type(parsed_prompt) is not Blend:
+    if type(parsed_prompt) is Blend:
+        blend: Blend = parsed_prompt
+        all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
+        longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
+        eos_token_index = len(longest_token_sequence)+1
+    else:
         tokens = get_tokens_for_prompt(model, parsed_prompt)
         eos_token_index = len(tokens)+1
     return (
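For blends the end-of-sequence index can no longer come from a single token list: the new branch tokenizes every sub-prompt and measures the longest one, while a non-blend prompt keeps the previous calculation. The +1 is consistent with CLIP placing a begin-of-sequence token at position 0, so the EOS marker follows the last prompt token. A hedged sketch of just this index computation, where tokenize_fn is a hypothetical stand-in for get_tokens_for_prompt:

from typing import Callable, List

# Sketch of the eos_token_index logic in the diff; tokenize_fn is a
# hypothetical stand-in for get_tokens_for_prompt(model, prompt).
def eos_token_index_for(sub_prompt_texts: List[str],
                        is_blend: bool,
                        tokenize_fn: Callable[[str], List[str]]) -> int:
    if is_blend:
        # A blend conditions on several prompts; use the longest token run.
        all_token_sequences = [tokenize_fn(p) for p in sub_prompt_texts]
        longest = max(all_token_sequences, key=len)
        return len(longest) + 1  # +1: BOS occupies index 0
    # Plain prompt: same formula over its single token sequence.
    return len(tokenize_fn(sub_prompt_texts[0])) + 1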