Update compel nodes to accept an optional prompt mask.

This commit is contained in:
Ryan Dick 2024-03-08 10:48:45 -05:00
parent 1da8423f04
commit bf3ee1fefa

View File

@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@ -36,7 +44,7 @@ from .model import CLIPField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@ -51,6 +59,9 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -117,8 +128,12 @@ class CompelInvocation(BaseInvocation):
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)
class SDXLPromptInvocationBase:
@ -232,7 +247,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -255,6 +270,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -317,7 +335,12 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)
@invocation(