mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes, zero tensor for empty negative prompt, remove raw prompt node
This commit is contained in:
parent
9aaf67c5b4
commit
b0738b7f70
@ -185,7 +185,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix, zero_on_empty):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@ -195,83 +195,22 @@ class SDXLPromptInvocationBase:
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
# return zero on empty
|
||||||
for lora in clip_field.loras:
|
if prompt == "" and zero_on_empty:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
cpu_text_encoder = text_encoder_info.context.model
|
||||||
yield (lora_info.context.model, lora.weight)
|
c = torch.zeros(
|
||||||
del lora_info
|
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
||||||
return
|
dtype=text_encoder_info.context.cache.precision,
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
|
||||||
|
|
||||||
ti_list = []
|
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
(
|
|
||||||
name,
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=clip_field.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
context=context,
|
|
||||||
).context.model,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except ModelNotFoundException:
|
|
||||||
# print(e)
|
|
||||||
# import traceback
|
|
||||||
# print(traceback.format_exc())
|
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
|
||||||
|
|
||||||
with ModelPatcher.apply_lora(
|
|
||||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
|
||||||
tokenizer,
|
|
||||||
ti_manager,
|
|
||||||
), ModelPatcher.apply_clip_skip(
|
|
||||||
text_encoder_info.context.model, clip_field.skipped_layers
|
|
||||||
), text_encoder_info as text_encoder:
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
prompt_embeds = text_encoder(
|
|
||||||
text_input_ids.to(text_encoder.device),
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
)
|
||||||
if get_pooled:
|
if get_pooled:
|
||||||
c_pooled = prompt_embeds[0]
|
c_pooled = torch.zeros(
|
||||||
|
(1, cpu_text_encoder.config.hidden_size),
|
||||||
|
dtype=c.dtype,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
c_pooled = None
|
c_pooled = None
|
||||||
c = prompt_embeds.hidden_states[-2]
|
|
||||||
|
|
||||||
del tokenizer
|
|
||||||
del text_encoder
|
|
||||||
del tokenizer_info
|
|
||||||
del text_encoder_info
|
|
||||||
|
|
||||||
c = c.detach().to("cpu")
|
|
||||||
if c_pooled is not None:
|
|
||||||
c_pooled = c_pooled.detach().to("cpu")
|
|
||||||
|
|
||||||
return c, c_pooled, None
|
return c, c_pooled, None
|
||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.tokenizer.dict(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.text_encoder.dict(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||||
@ -375,11 +314,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
|
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=False)
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True)
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
||||||
|
|
||||||
|
print(f"{c1.shape=} {c2.shape=} {c2_pooled.shape=} {self.prompt=}")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -434,118 +375,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
# TODO: if there will appear lora for refiner - write proper prefix
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
|
||||||
conditionings=[
|
|
||||||
SDXLConditioningInfo(
|
|
||||||
embeds=c2,
|
|
||||||
pooled_embeds=c2_pooled,
|
|
||||||
add_time_ids=add_time_ids,
|
|
||||||
extra_conditioning=ec2, # or None
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|
||||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
|
||||||
|
|
||||||
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
|
||||||
style: str = Field(default="", description="Style prompt")
|
|
||||||
original_width: int = Field(1024, description="")
|
|
||||||
original_height: int = Field(1024, description="")
|
|
||||||
crop_top: int = Field(0, description="")
|
|
||||||
crop_left: int = Field(0, description="")
|
|
||||||
target_width: int = Field(1024, description="")
|
|
||||||
target_height: int = Field(1024, description="")
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
|
||||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
|
||||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
|
|
||||||
if self.style.strip() == "":
|
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
|
|
||||||
else:
|
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
|
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
|
||||||
target_size = (self.target_height, self.target_width)
|
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
|
||||||
conditionings=[
|
|
||||||
SDXLConditioningInfo(
|
|
||||||
embeds=torch.cat([c1, c2], dim=-1),
|
|
||||||
pooled_embeds=c2_pooled,
|
|
||||||
add_time_ids=add_time_ids,
|
|
||||||
extra_conditioning=ec1,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|
||||||
"""Parse prompt using compel package to conditioning."""
|
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
|
||||||
original_width: int = Field(1024, description="")
|
|
||||||
original_height: int = Field(1024, description="")
|
|
||||||
crop_top: int = Field(0, description="")
|
|
||||||
crop_left: int = Field(0, description="")
|
|
||||||
aesthetic_score: float = Field(6.0, description="")
|
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Refiner Prompt (Raw)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
|
||||||
# TODO: if there will appear lora for refiner - write proper prefix
|
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
|
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
@ -386,8 +386,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma,
|
||||||
unconditioning: torch.Tensor,
|
conditioning_data,
|
||||||
conditioning: torch.Tensor,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# low-memory sequential path
|
# low-memory sequential path
|
||||||
@ -444,8 +443,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
conditioning_data,
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user