Prompting: enable long prompts and compel's new .and() concatenating feature (#3497)

this PR adds long prompt support and enables compel's new `.and()`
concatenation feature which improves image quality especially with SD2.1

example of a long prompt:
> a moist sloppy pindlesackboy sloppy hamblin' bogomadong, Clem Fandango
is pissed-off, Wario's Woods in background, making a noise like
ga-woink-a
![000075 6dfd7adf
466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/051608b6-8d52-463b-af10-04b695cda9c1)

the same prompt broken into fragments and concatenated using `.and()`
(syntax works like `.blend()`):
```
("a moist sloppy pindlesackboy sloppy hamblin' bogomadong", 
"Clem Fandango is pissed-off", 
"Wario's Woods in background", 
"making a noise like ga-woink-a").and()
```
![000076 68b1c320
466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/3fee291f-5562-40f9-9c3c-a73765fc893a)


and a less silly example:

> A dream of a distant galaxy, by Caspar David Friedrich, matte
painting, trending on artstation, HQ
![000129 1b33b559
2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/d4113756-ed0d-49cd-bb2e-a2fc4a09e0af)

the same prompt broken into two fragments and concatenated:
```
("A dream of a distant galaxy, by Caspar David Friedrich, matte painting", 
"trending on artstation, HQ").and()
```
![000128 b5d5cd62
2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/c373c009-05db-4c42-8a1d-c89fbdb334ec)

as with `.blend()` you can also weight the parts eg `("a man eating an
apple", "sitting on the roof of a car", "high quality, trending on
artstation, 8K UHD").and(1, 0.5, 0.5)` which will assign weight `1` to
`a man eating an apple` and `0.5` to `sitting on the roof of a car` and
`high quality, trending on artstation, 8K UHD`.
This commit is contained in:
blessedcoolant 2023-06-05 04:53:08 +12:00 committed by GitHub
commit 25b8dd340a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 18 deletions

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.prompting.conditioning import try_parse_legacy_blend
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@ -13,7 +14,7 @@ from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
Fragment, Conjunction,
)
@ -93,25 +94,22 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
truncate_long_prompts=False,
)
# TODO: support legacy blend?
conjunction = Compel.parse_prompt_string(prompt_str)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
if legacy_blend is not None:
conjunction = legacy_blend
else:
conjunction = Compel.parse_prompt_string(prompt_str)
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@ -128,14 +126,22 @@ class CompelInvocation(BaseInvocation):
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
[
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in blend.prompts
]
)
elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt
return sum(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
]
)
else:
@ -170,6 +176,22 @@ def get_tokens_for_prompt_object(
return tokens
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):

View File

@ -4,6 +4,7 @@ import random
import einops
from typing import Literal, Optional, Union, List
from compel import Compel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from pydantic import BaseModel, Field, validator
@ -233,6 +234,15 @@ class TextToLatentsInvocation(BaseInvocation):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
compel = Compel(
tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
conditioning_data = ConditioningData(
uc,
c,

View File

@ -38,7 +38,7 @@ def get_uc_and_c_and_ec(prompt_string,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
config = get_invokeai_config()
# get rid of any newline characters
@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text)
]
if len(parsed_prompts) == 0:
return []
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))

View File

@ -38,7 +38,7 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=1.1.5",
"compel>=1.2.1",
"controlnet-aux>=0.0.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",