enable long prompts, upgrade compel to enable .and() (concatenating prompts)

Damian Stewart 2023-06-04 15:30:54 +02:00
parent 82231369d3
commit cdcfda164d
4 changed files with 52 additions and 18 deletions
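In short: prompt parsing now always yields a compel Conjunction, so multi-part prompts written with compel's .and() syntax are encoded part by part and concatenated, long prompts are no longer truncated, and the old colon-weighted blend syntax is still honored via a legacy parser. A minimal sketch of the new parsing entry point, assuming compel >= 1.2 (the prompt string is illustrative, not from this commit):

    from compel import Compel

    # ".and()" concatenates independently weighted sub-prompts
    # instead of blending their embeddings into one prompt
    conjunction = Compel.parse_prompt_string('("a forest clearing", "oil painting").and()')
    print(type(conjunction).__name__)  # Conjunction
    print(len(conjunction.prompts))    # 2 -- each part is tokenized and encoded separately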

invokeai/app/invocations/compel.py

@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
 from invokeai.app.invocations.util.choose_model import choose_model
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+from ...backend.prompting.conditioning import try_parse_legacy_blend
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@@ -13,7 +14,7 @@ from compel.prompt_parser import (
     Blend,
     CrossAttentionControlSubstitute,
     FlattenedPrompt,
-    Fragment,
+    Fragment, Conjunction,
 )
@@ -93,25 +94,22 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=pipeline.textual_inversion_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True,  # TODO:
+            truncate_long_prompts=False,
         )

-        # TODO: support legacy blend?
-
-        conjunction = Compel.parse_prompt_string(prompt_str)
-        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
+        legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
+        if legacy_blend is not None:
+            conjunction = legacy_blend
+        else:
+            conjunction = Compel.parse_prompt_string(prompt_str)

         if context.services.configuration.log_tokenization:
-            log_tokenization_for_prompt_object(prompt, tokenizer)
+            log_tokenization_for_conjunction(conjunction, tokenizer)

-        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
-
-        # TODO: long prompt support
-        #if not self.truncate_long_prompts:
-        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+        c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)

         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
+            tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )
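Ordering note: try_parse_legacy_blend only recognizes the old InvokeAI colon-weighted syntax and returns None otherwise, so plain prompts and compel-syntax prompts fall through to Compel.parse_prompt_string. A hedged illustration of the two forms (example strings only; the equivalence is assumed, not stated in the diff):

    legacy = "a mountain:2 a lake:1"                     # old colon-weighted blend
    modern = '("a mountain", "a lake").blend(2.0, 1.0)'  # compel blend syntax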
@@ -128,14 +126,22 @@ class CompelInvocation(BaseInvocation):
 def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
+    tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
         return max(
             [
-                get_max_token_count(tokenizer, c, truncate_if_too_long)
-                for c in blend.prompts
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in blend.prompts
+            ]
+        )
+    elif type(prompt) is Conjunction:
+        conjunction: Conjunction = prompt
+        return sum(
+            [
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in conjunction.prompts
             ]
         )
     else:
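The sum/max split mirrors the semantics: a Conjunction's parts are encoded separately and concatenated along the token axis, so their token counts add, while a Blend's parts are encoded to a single shared length and interpolated, so only the longest part counts. For illustration, parts of 30 and 50 tokens report 80 as a Conjunction but 50 as a Blend.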
@@ -170,6 +176,22 @@ def get_tokens_for_prompt_object(
     return tokens

+
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer, display_label_prefix=None
+):
+    display_label_prefix = display_label_prefix or ""
+    for i, p in enumerate(c.prompts):
+        if len(c.prompts) > 1:
+            this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
+        else:
+            this_display_label_prefix = display_label_prefix
+        log_tokenization_for_prompt_object(
+            p,
+            tokenizer,
+            display_label_prefix=this_display_label_prefix
+        )
+
+
 def log_tokenization_for_prompt_object(
     p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
 ):
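With this helper, a two-part conjunction logs each part under a label suffix like (conjunction part 1, weight=1.0) and (conjunction part 2, weight=1.0) (weights illustrative), while a single-part conjunction keeps the caller's prefix unchanged, so logging for ordinary prompts is unaffected.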

invokeai/app/invocations/latent.py

@@ -4,6 +4,7 @@ import random
 import einops
 from typing import Literal, Optional, Union, List

+from compel import Compel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from pydantic import BaseModel, Field, validator
@@ -233,6 +234,15 @@ class TextToLatentsInvocation(BaseInvocation):
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

+        compel = Compel(
+            tokenizer=model.tokenizer,
+            text_encoder=model.text_encoder,
+            textual_inversion_manager=model.textual_inversion_manager,
+            dtype_for_device_getter=torch_dtype,
+            truncate_long_prompts=False,
+        )
+        [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+
         conditioning_data = ConditioningData(
             uc,
             c,
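This padding is what actually makes long prompts usable: with truncate_long_prompts=False the positive prompt may span several 77-token chunks while the negative prompt occupies one, and classifier-free guidance needs both conditioning tensors to have the same sequence length. A minimal sketch of the shape mismatch being fixed (shapes and the 768 hidden size are illustrative for SD 1.x):

    import torch

    c = torch.randn(1, 154, 768)   # long positive prompt: two 77-token chunks
    uc = torch.randn(1, 77, 768)   # short negative prompt: one chunk
    # CFG batches these together, which requires equal sequence lengths:
    assert c.shape[1] != uc.shape[1]  # torch.cat([uc, c]) would raise here
    # hence: [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])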

invokeai/backend/prompting/conditioning.py

@@ -38,7 +38,7 @@ def get_uc_and_c_and_ec(prompt_string,
         dtype_for_device_getter=torch_dtype,
         truncate_long_prompts=False,
     )

     config = get_invokeai_config()

     # get rid of any newline characters
@@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
         (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
         for match in re.finditer(prompt_parser, text)
     ]
+    if len(parsed_prompts) == 0:
+        return []
     if skip_normalize:
         return parsed_prompts
     weight_sum = sum(map(lambda x: x[1], parsed_prompts))
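The guard protects the normalization below it: weight_sum is zero for an empty parse. Assuming the function's usual divide-by-sum normalization (the remainder of its body is outside this hunk), the behavior is roughly:

    split_weighted_subprompts("mountain:2 lake:1")
    # -> [("mountain", 0.666...), ("lake", 0.333...)]
    split_weighted_subprompts("")
    # -> [] with this change, rather than dividing by a zero weight_sum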

pyproject.toml

@@ -38,7 +38,7 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-    "compel~=1.1.5",
+    "compel>=1.2.1",
     "controlnet-aux>=0.0.4",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
     "datasets",