From bf3ee1fefa57e786cb68ec5ddf7a1801bf626c92 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 10:48:45 -0500 Subject: [PATCH] Update compel nodes to accept an optional prompt mask. --- invokeai/app/invocations/compel.py | 35 +++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c23dd3d908..6df3301362 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ConditioningField, + FieldDescriptions, + Input, + InputField, + MaskField, + OutputField, + UIComponent, +) from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list @@ -36,7 +44,7 @@ from .model import CLIPField title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -51,6 +59,9 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -117,8 +128,12 @@ class CompelInvocation(BaseInvocation): ) conditioning_name = context.conditioning.save(conditioning_data) - - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) class SDXLPromptInvocationBase: @@ -232,7 +247,7 @@ class SDXLPromptInvocationBase: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -255,6 +270,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): target_height: int = InputField(default=1024, description="") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -317,7 +335,12 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) @invocation(