Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit cdcfda164d (parent 82231369d3)

    enable long prompts, upgrade compel to enable .and() (concatenating prompts)
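For reference, the .and() syntax this commit enables comes from compel's prompt parser: the sub-prompts are encoded separately and concatenated rather than blended, and parsing yields a Conjunction object. A minimal sketch of the parse step (the prompt text is illustrative):

    from compel import Compel
    from compel.prompt_parser import Conjunction

    # ".and()" concatenates sub-prompts instead of blending them; each part
    # keeps its own weight, exposed on the resulting Conjunction.
    conjunction = Compel.parse_prompt_string('("a cat", "playing with wool").and()')
    assert isinstance(conjunction, Conjunction)
    print(len(conjunction.prompts), conjunction.weights)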
@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field

 from invokeai.app.invocations.util.choose_model import choose_model
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+from ...backend.prompting.conditioning import try_parse_legacy_blend

 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@@ -13,7 +14,7 @@ from compel.prompt_parser import (
     Blend,
     CrossAttentionControlSubstitute,
     FlattenedPrompt,
-    Fragment,
+    Fragment, Conjunction,
 )

@@ -93,25 +94,22 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=pipeline.textual_inversion_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True, # TODO:
+            truncate_long_prompts=False,
         )

-        # TODO: support legacy blend?
-
-        conjunction = Compel.parse_prompt_string(prompt_str)
-        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
+        legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
+        if legacy_blend is not None:
+            conjunction = legacy_blend
+        else:
+            conjunction = Compel.parse_prompt_string(prompt_str)

         if context.services.configuration.log_tokenization:
-            log_tokenization_for_prompt_object(prompt, tokenizer)
+            log_tokenization_for_conjunction(conjunction, tokenizer)

-        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
+        c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)

-        # TODO: long prompt support
-        #if not self.truncate_long_prompts:
-        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
-
         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
+            tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )
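Flipping truncate_long_prompts to False is the "enable long prompts" half of the commit. A minimal sketch of the behavior change, assuming a CLIP tokenizer and text encoder are already in scope (variable names illustrative):

    from compel import Compel

    compel = Compel(
        tokenizer=tokenizer,          # assumption: a loaded CLIP tokenizer
        text_encoder=text_encoder,    # assumption: the matching text encoder
        truncate_long_prompts=False,  # was True: tokens past CLIP's 77-token window were dropped
    )
    # Long prompts now encode to several 77-token chunks that compel concatenates,
    # rather than being silently cut off at the window boundary.
    c = compel.build_conditioning_tensor("a highly detailed painting, " * 30)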
@@ -128,14 +126,22 @@ class CompelInvocation(BaseInvocation):


 def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
+    tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
         return max(
             [
-                get_max_token_count(tokenizer, c, truncate_if_too_long)
-                for c in blend.prompts
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in blend.prompts
+            ]
+        )
+    elif type(prompt) is Conjunction:
+        conjunction: Conjunction = prompt
+        return sum(
+            [
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in conjunction.prompts
             ]
         )
     else:
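The Conjunction branch sums where the Blend branch takes a max, which matches how each is encoded: a blend's sub-prompts share one context window (they are encoded in parallel and mixed), so the longest part sets the token budget, while a conjunction's parts are encoded separately and concatenated, so their lengths add. Illustratively, with made-up counts:

    part_token_counts = [77, 45]
    blend_budget = max(part_token_counts)        # 77: parts share one window
    conjunction_budget = sum(part_token_counts)  # 122: parts are concatenated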
@@ -170,6 +176,22 @@ def get_tokens_for_prompt_object(
     return tokens


+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer, display_label_prefix=None
+):
+    display_label_prefix = display_label_prefix or ""
+    for i, p in enumerate(c.prompts):
+        if len(c.prompts)>1:
+            this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
+        else:
+            this_display_label_prefix = display_label_prefix
+        log_tokenization_for_prompt_object(
+            p,
+            tokenizer,
+            display_label_prefix=this_display_label_prefix
+        )
+
+
 def log_tokenization_for_prompt_object(
     p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
 ):
@@ -4,6 +4,7 @@ import random
 import einops
 from typing import Literal, Optional, Union, List

+from compel import Compel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

 from pydantic import BaseModel, Field, validator
@@ -233,6 +234,15 @@ class TextToLatentsInvocation(BaseInvocation):
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

+        compel = Compel(
+            tokenizer=model.tokenizer,
+            text_encoder=model.text_encoder,
+            textual_inversion_manager=model.textual_inversion_manager,
+            dtype_for_device_getter=torch_dtype,
+            truncate_long_prompts=False,
+        )
+        [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+
         conditioning_data = ConditioningData(
             uc,
             c,
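The padding call is what keeps classifier-free guidance working once truncation is off: the positive and negative prompts can now encode to different numbers of token chunks, and the guidance step needs both conditioning tensors in a single shape. A sketch of the mismatch it resolves (shapes illustrative for an SD-1.x text encoder):

    import torch

    c = torch.randn(1, 154, 768)   # long positive prompt: two 77-token chunks
    uc = torch.randn(1, 77, 768)   # short negative prompt: one chunk
    # pad_conditioning_tensors_to_same_length extends the shorter tensor to the
    # longer length so the pair can be batched during guidance.
    assert c.shape[1] != uc.shape[1]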
@@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
         (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
         for match in re.finditer(prompt_parser, text)
     ]
+    if len(parsed_prompts) == 0:
+        return []
     if skip_normalize:
         return parsed_prompts
     weight_sum = sum(map(lambda x: x[1], parsed_prompts))
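The early return makes the no-match case explicit instead of letting an empty list fall through to weight normalization. Expected behavior, hedged because the matching regex is defined earlier in the file and not shown in this hunk:

    # Legacy weighted-subprompt syntax is "prompt:weight prompt:weight ...".
    split_weighted_subprompts("mountain:1 lake:0.5", skip_normalize=True)
    # -> [("mountain", 1.0), ("lake", 0.5)]   (assumed match behavior)
    split_weighted_subprompts("")
    # -> [] with this change, short-circuiting before normalization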
@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel~=1.1.5",
+  "compel>=1.2.1",
   "controlnet-aux>=0.0.4",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "datasets",
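The version floor moves with the API: the Conjunction-returning parse_prompt_string, build_conditioning_tensor_for_conjunction, and the .and() syntax relied on above are what motivate replacing the loose ~=1.1.5 pin with compel>=1.2.1.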