mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Separate prompt to sdxl and sdxl-refiner, add denoising start-end fields, add l2l node(supports both sdxl and sdxl-refiner), add fp32 to vae encode
This commit is contained in:
@ -36,6 +36,7 @@ class BasicConditioningInfo:
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
ConditioningInfoType = Annotated[
|
||||
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
@ -300,6 +301,12 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip1: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
@ -385,11 +392,20 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + target_size
|
||||
])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=torch.cat([c1, c2], dim=-1),
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec1,
|
||||
)
|
||||
]
|
||||
@ -404,6 +420,124 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
class SDXLRefinerRawPromptInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def run_clip(self, context, clip_field, prompt):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
c_pooled = prompt_embeds[0]
|
||||
c = prompt_embeds.hidden_states[-2]
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
del text_encoder_info
|
||||
|
||||
return c.detach(), c_pooled.detach(), None
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + (self.aesthetic_score,)
|
||||
])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
|
Reference in New Issue
Block a user