Add concatenation logic to sdxl-compel node

This commit is contained in:
Sergey Borisov 2023-08-31 04:46:37 +03:00
parent d5267357b1
commit 7e46f8f1c5

View File

@ -273,6 +273,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
concat_style: bool = InputField(default=True, description="Enable concatenation 'prompt' to 'style' field")
original_width: int = InputField(default=1024, description="") original_width: int = InputField(default=1024, description="")
original_height: int = InputField(default=1024, description="") original_height: int = InputField(default=1024, description="")
crop_top: int = InputField(default=0, description="") crop_top: int = InputField(default=0, description="")
@ -284,17 +285,20 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
if self.style == "":
style = self.prompt
elif self.prompt == "":
style = self.style
else: # if both style and prompt not empty
style = self.prompt + " " + self.style
c1, c1_pooled, ec1 = self.run_clip_compel( c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
) )
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel( c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True context, self.clip2, style, True, "lora_te2_", zero_on_empty=True
) )
else:
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
)
original_size = (self.original_height, self.original_width) original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)