mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Implement compel prompt nodes for sdxl
This commit is contained in:
parent
e039771d07
commit
ada9b06e48
@ -2,7 +2,7 @@ from typing import Literal, Optional, Union, List, Annotated
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import (Blend, Conjunction,
|
from compel.prompt_parser import (Blend, Conjunction,
|
||||||
CrossAttentionControlSubstitute,
|
CrossAttentionControlSubstitute,
|
||||||
FlattenedPrompt, Fragment)
|
FlattenedPrompt, Fragment)
|
||||||
@ -165,170 +165,8 @@ class CompelInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: implement with compel package update
|
class SDXLPromptInvocationBase:
|
||||||
class SDXLCompelInvocation(BaseInvocation):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
|
||||||
|
|
||||||
type: Literal["sdxl_compel"] = "sdxl_compel"
|
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
|
||||||
clip1: ClipField = Field(None, description="Clip to use")
|
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Prompt (Compel)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_clip(self, context, clip_field):
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.tokenizer.dict(),
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.text_encoder.dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _lora_loader():
|
|
||||||
for lora in clip_field.loras:
|
|
||||||
lora_info = context.services.model_manager.get_model(
|
|
||||||
**lora.dict(exclude={"weight"}))
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
|
||||||
del lora_info
|
|
||||||
return
|
|
||||||
|
|
||||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
|
||||||
|
|
||||||
ti_list = []
|
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=clip_field.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model
|
|
||||||
)
|
|
||||||
except ModelNotFoundException:
|
|
||||||
# print(e)
|
|
||||||
#import traceback
|
|
||||||
#print(traceback.format_exc())
|
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
|
||||||
text_encoder_info as text_encoder:
|
|
||||||
|
|
||||||
compel = Compel(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
textual_inversion_manager=ti_manager,
|
|
||||||
dtype_for_device_getter=torch_dtype,
|
|
||||||
truncate_long_prompts=True, # TODO:
|
|
||||||
)
|
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
|
||||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
|
||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
|
||||||
|
|
||||||
### TODO: pooled
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
self.prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
prompt_embeds = text_encoder(
|
|
||||||
text_input_ids.to(text_encoder.device),
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
c_pooled = prompt_embeds[0]
|
|
||||||
c = prompt_embeds.hidden_states[-2]
|
|
||||||
### TODO: pooled
|
|
||||||
|
|
||||||
# TODO: long prompt support
|
|
||||||
# if not self.truncate_long_prompts:
|
|
||||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
del tokenizer
|
|
||||||
del text_encoder
|
|
||||||
del tokenizer_info
|
|
||||||
del text_encoder_info
|
|
||||||
del compel
|
|
||||||
|
|
||||||
return c.detach(), c_pooled.detach(), None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
|
||||||
c1, c1_pooled, ec1 = self.run_clip(context, self.clip1)
|
|
||||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2)
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
|
||||||
conditionings=[
|
|
||||||
SDXLConditioningInfo(
|
|
||||||
embeds=torch.cat([c1, c2], dim=-1),
|
|
||||||
pooled_embeds=c2_pooled,
|
|
||||||
extra_conditioning=ec1,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
class SDXLRawPromptInvocation(BaseInvocation):
|
|
||||||
"""Parse prompt using compel package to conditioning."""
|
|
||||||
|
|
||||||
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
|
||||||
style: str = Field(default="", description="Style prompt")
|
|
||||||
original_width: int = Field(1024, description="")
|
|
||||||
original_height: int = Field(1024, description="")
|
|
||||||
crop_top: int = Field(0, description="")
|
|
||||||
crop_left: int = Field(0, description="")
|
|
||||||
target_width: int = Field(1024, description="")
|
|
||||||
target_height: int = Field(1024, description="")
|
|
||||||
clip1: ClipField = Field(None, description="Clip to use")
|
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Prompt (Raw)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_clip(self, context, clip_field, prompt):
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
)
|
)
|
||||||
@ -380,7 +218,10 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
text_input_ids.to(text_encoder.device),
|
text_input_ids.to(text_encoder.device),
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
c_pooled = prompt_embeds[0]
|
if get_pooled:
|
||||||
|
c_pooled = prompt_embeds[0]
|
||||||
|
else:
|
||||||
|
c_pooled = None
|
||||||
c = prompt_embeds.hidden_states[-2]
|
c = prompt_embeds.hidden_states[-2]
|
||||||
|
|
||||||
del tokenizer
|
del tokenizer
|
||||||
@ -388,15 +229,119 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
del tokenizer_info
|
del tokenizer_info
|
||||||
del text_encoder_info
|
del text_encoder_info
|
||||||
|
|
||||||
return c.detach(), c_pooled.detach(), None
|
return c, c_pooled, None
|
||||||
|
|
||||||
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||||
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
|
**clip_field.tokenizer.dict(),
|
||||||
|
)
|
||||||
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
|
**clip_field.text_encoder.dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _lora_loader():
|
||||||
|
for lora in clip_field.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}))
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
|
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
|
ti_list = []
|
||||||
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
|
name = trigger[1:-1]
|
||||||
|
try:
|
||||||
|
ti_list.append(
|
||||||
|
context.services.model_manager.get_model(
|
||||||
|
model_name=name,
|
||||||
|
base_model=clip_field.text_encoder.base_model,
|
||||||
|
model_type=ModelType.TextualInversion,
|
||||||
|
).context.model
|
||||||
|
)
|
||||||
|
except ModelNotFoundException:
|
||||||
|
# print(e)
|
||||||
|
#import traceback
|
||||||
|
#print(traceback.format_exc())
|
||||||
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||||
|
text_encoder_info as text_encoder:
|
||||||
|
|
||||||
|
compel = Compel(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
textual_inversion_manager=ti_manager,
|
||||||
|
dtype_for_device_getter=torch_dtype,
|
||||||
|
truncate_long_prompts=True, # TODO:
|
||||||
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
|
requires_pooled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
conjunction = Compel.parse_prompt_string(prompt)
|
||||||
|
|
||||||
|
if context.services.configuration.log_tokenization:
|
||||||
|
# TODO: better logging for and syntax
|
||||||
|
for prompt_obj in conjunction.prompts:
|
||||||
|
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
|
||||||
|
|
||||||
|
# TODO: ask for optimizations? to not run text_encoder twice
|
||||||
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
|
if get_pooled:
|
||||||
|
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
|
||||||
|
else:
|
||||||
|
c_pooled = None
|
||||||
|
|
||||||
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
|
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||||
|
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
del tokenizer
|
||||||
|
del text_encoder
|
||||||
|
del tokenizer_info
|
||||||
|
del text_encoder_info
|
||||||
|
|
||||||
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||||
|
|
||||||
|
prompt: str = Field(default="", description="Prompt")
|
||||||
|
style: str = Field(default="", description="Style prompt")
|
||||||
|
original_width: int = Field(1024, description="")
|
||||||
|
original_height: int = Field(1024, description="")
|
||||||
|
crop_top: int = Field(0, description="")
|
||||||
|
crop_left: int = Field(0, description="")
|
||||||
|
target_width: int = Field(1024, description="")
|
||||||
|
target_height: int = Field(1024, description="")
|
||||||
|
clip1: ClipField = Field(None, description="Clip to use")
|
||||||
|
clip2: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Prompt (Compel)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip(context, self.clip1, self.prompt)
|
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip1, self.prompt, False)
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.prompt)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -426,10 +371,10 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
class SDXLRefinerRawPromptInvocation(BaseInvocation):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = Field(1024, description="")
|
||||||
@ -443,7 +388,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation):
|
|||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Prompt (Raw)",
|
"title": "SDXL Refiner Prompt (Compel)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model"
|
"model": "model"
|
||||||
@ -451,71 +396,9 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_clip(self, context, clip_field, prompt):
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.tokenizer.dict(),
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.text_encoder.dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _lora_loader():
|
|
||||||
for lora in clip_field.loras:
|
|
||||||
lora_info = context.services.model_manager.get_model(
|
|
||||||
**lora.dict(exclude={"weight"}))
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
|
||||||
del lora_info
|
|
||||||
return
|
|
||||||
|
|
||||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
|
||||||
|
|
||||||
ti_list = []
|
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=clip_field.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model
|
|
||||||
)
|
|
||||||
except ModelNotFoundException:
|
|
||||||
# print(e)
|
|
||||||
#import traceback
|
|
||||||
#print(traceback.format_exc())
|
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
|
||||||
text_encoder_info as text_encoder:
|
|
||||||
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
prompt_embeds = text_encoder(
|
|
||||||
text_input_ids.to(text_encoder.device),
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
c_pooled = prompt_embeds[0]
|
|
||||||
c = prompt_embeds.hidden_states[-2]
|
|
||||||
|
|
||||||
del tokenizer
|
|
||||||
del text_encoder
|
|
||||||
del tokenizer_info
|
|
||||||
del text_encoder_info
|
|
||||||
|
|
||||||
return c.detach(), c_pooled.detach(), None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -544,6 +427,127 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
||||||
|
|
||||||
|
prompt: str = Field(default="", description="Prompt")
|
||||||
|
style: str = Field(default="", description="Style prompt")
|
||||||
|
original_width: int = Field(1024, description="")
|
||||||
|
original_height: int = Field(1024, description="")
|
||||||
|
crop_top: int = Field(0, description="")
|
||||||
|
crop_left: int = Field(0, description="")
|
||||||
|
target_width: int = Field(1024, description="")
|
||||||
|
target_height: int = Field(1024, description="")
|
||||||
|
clip1: ClipField = Field(None, description="Clip to use")
|
||||||
|
clip2: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Prompt (Raw)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
|
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip1, self.prompt, False)
|
||||||
|
if self.style.strip() == "":
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
||||||
|
else:
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||||
|
|
||||||
|
original_size = (self.original_height, self.original_width)
|
||||||
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + target_size
|
||||||
|
])
|
||||||
|
|
||||||
|
conditioning_data = ConditioningFieldData(
|
||||||
|
conditionings=[
|
||||||
|
SDXLConditioningInfo(
|
||||||
|
embeds=torch.cat([c1, c2], dim=-1),
|
||||||
|
pooled_embeds=c2_pooled,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
|
extra_conditioning=ec1,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
context.services.latents.save(conditioning_name, conditioning_data)
|
||||||
|
|
||||||
|
return CompelOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||||
|
|
||||||
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
|
original_width: int = Field(1024, description="")
|
||||||
|
original_height: int = Field(1024, description="")
|
||||||
|
crop_top: int = Field(0, description="")
|
||||||
|
crop_left: int = Field(0, description="")
|
||||||
|
aesthetic_score: float = Field(6.0, description="")
|
||||||
|
clip2: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Refiner Prompt (Raw)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||||
|
|
||||||
|
original_size = (self.original_height, self.original_width)
|
||||||
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|
||||||
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + (self.aesthetic_score,)
|
||||||
|
])
|
||||||
|
|
||||||
|
conditioning_data = ConditioningFieldData(
|
||||||
|
conditionings=[
|
||||||
|
SDXLConditioningInfo(
|
||||||
|
embeds=c2,
|
||||||
|
pooled_embeds=c2_pooled,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
|
extra_conditioning=ec2, # or None
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
context.services.latents.save(conditioning_name, conditioning_data)
|
||||||
|
|
||||||
|
return CompelOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
|
Loading…
Reference in New Issue
Block a user