mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Prompting: enable long prompts and compel's new .and()
concatenating feature (#3497)
this PR adds long prompt support and enables compel's new `.and()` concatenation feature which improves image quality especially with SD2.1 example of a long prompt: > a moist sloppy pindlesackboy sloppy hamblin' bogomadong, Clem Fandango is pissed-off, Wario's Woods in background, making a noise like ga-woink-a ![000075 6dfd7adf 466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/051608b6-8d52-463b-af10-04b695cda9c1) the same prompt broken into fragments and concatenated using `.and()` (syntax works like `.blend()`): ``` ("a moist sloppy pindlesackboy sloppy hamblin' bogomadong", "Clem Fandango is pissed-off", "Wario's Woods in background", "making a noise like ga-woink-a").and() ``` ![000076 68b1c320 466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/3fee291f-5562-40f9-9c3c-a73765fc893a) and a less silly example: > A dream of a distant galaxy, by Caspar David Friedrich, matte painting, trending on artstation, HQ ![000129 1b33b559 2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/d4113756-ed0d-49cd-bb2e-a2fc4a09e0af) the same prompt broken into two fragments and concatenated: ``` ("A dream of a distant galaxy, by Caspar David Friedrich, matte painting", "trending on artstation, HQ").and() ``` ![000128 b5d5cd62 2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/c373c009-05db-4c42-8a1d-c89fbdb334ec) as with `.blend()` you can also weight the parts eg `("a man eating an apple", "sitting on the roof of a car", "high quality, trending on artstation, 8K UHD").and(1, 0.5, 0.5)` which will assign weight `1` to `a man eating an apple` and `0.5` to `sitting on the roof of a car` and `high quality, trending on artstation, 8K UHD`.
This commit is contained in:
commit
25b8dd340a
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from ...backend.prompting.conditioning import try_parse_legacy_blend
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
@ -13,7 +14,7 @@ from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
Fragment, Conjunction,
|
||||
)
|
||||
|
||||
|
||||
@ -93,25 +94,22 @@ class CompelInvocation(BaseInvocation):
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# TODO: support legacy blend?
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
|
||||
if legacy_blend is not None:
|
||||
conjunction = legacy_blend
|
||||
else:
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
@ -128,14 +126,22 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
||||
for c in blend.prompts
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
@ -170,6 +176,22 @@ def get_tokens_for_prompt_object(
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts)>1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
|
@ -4,6 +4,7 @@ import random
|
||||
import einops
|
||||
from typing import Literal, Optional, Union, List
|
||||
|
||||
from compel import Compel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
@ -233,6 +234,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
|
@ -38,7 +38,7 @@ def get_uc_and_c_and_ec(prompt_string,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
|
||||
config = get_invokeai_config()
|
||||
|
||||
# get rid of any newline characters
|
||||
@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||
for match in re.finditer(prompt_parser, text)
|
||||
]
|
||||
if len(parsed_prompts) == 0:
|
||||
return []
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
|
@ -38,7 +38,7 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel~=1.1.5",
|
||||
"compel>=1.2.1",
|
||||
"controlnet-aux>=0.0.4",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"datasets",
|
||||
|
Loading…
Reference in New Issue
Block a user