mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restrict to 75 tokens and correctly handle blends
This commit is contained in:
parent
5d20f47993
commit
64339af2dc
@ -36,13 +36,16 @@ Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
return prompt, negative_prompt
|
||||
|
||||
|
||||
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
|
||||
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
|
||||
text_fragments = [x.text if type(x) is Fragment else
|
||||
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
|
||||
str(x))
|
||||
for x in parsed_prompt.children]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
|
||||
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
cac_args = None
|
||||
|
||||
eos_token_index = 1
|
||||
if type(parsed_prompt) is not Blend:
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
|
||||
longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
|
||||
eos_token_index = len(longest_token_sequence)+1
|
||||
else:
|
||||
tokens = get_tokens_for_prompt(model, parsed_prompt)
|
||||
eos_token_index = len(tokens)+1
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user