diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index d0b55cd185..7fd101a3a0 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -185,7 +185,7 @@ class CompelInvocation(BaseInvocation):
 
 
 class SDXLPromptInvocationBase:
-    def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
+    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix, zero_on_empty):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -195,82 +195,21 @@ class SDXLPromptInvocationBase:
             context=context,
         )
 
-        def _lora_loader():
-            for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return
-
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
-
-        ti_list = []
-        for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
-            name = trigger[1:-1]
-            try:
-                ti_list.append(
-                    (
-                        name,
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
-
-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, clip_field.skipped_layers
-        ), text_encoder_info as text_encoder:
-            text_inputs = tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            prompt_embeds = text_encoder(
-                text_input_ids.to(text_encoder.device),
-                output_hidden_states=True,
+        # return zero on empty
+        if prompt == "" and zero_on_empty:
+            cpu_text_encoder = text_encoder_info.context.model
+            c = torch.zeros(
+                (1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
+                dtype=text_encoder_info.context.cache.precision,
             )
             if get_pooled:
-                c_pooled = prompt_embeds[0]
+                c_pooled = torch.zeros(
+                    (1, cpu_text_encoder.config.hidden_size),
+                    dtype=c.dtype,
+                )
             else:
                 c_pooled = None
-            c = prompt_embeds.hidden_states[-2]
-
-        del tokenizer
-        del text_encoder
-        del tokenizer_info
-        del text_encoder_info
-
-        c = c.detach().to("cpu")
-        if c_pooled is not None:
-            c_pooled = c_pooled.detach().to("cpu")
-
-        return c, c_pooled, None
-
-    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
-        tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
-            context=context,
-        )
+            return c, c_pooled, None
 
         def _lora_loader():
             for lora in clip_field.loras:
@@ -375,11 +314,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
+        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=False)
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True)
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -434,118 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "")
-
-        original_size = (self.original_height, self.original_width)
-        crop_coords = (self.crop_top, self.crop_left)
-
-        add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
-
-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=c2,
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec2,  # or None
-                )
-            ]
-        )
-
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
-
-        return CompelOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
-
-
-class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
-    """Pass unmodified prompt to conditioning without compel processing."""
-
-    type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
-
-    prompt: str = Field(default="", description="Prompt")
-    style: str = Field(default="", description="Style prompt")
-    original_width: int = Field(1024, description="")
-    original_height: int = Field(1024, description="")
-    crop_top: int = Field(0, description="")
-    crop_left: int = Field(0, description="")
-    target_width: int = Field(1024, description="")
-    target_height: int = Field(1024, description="")
-    clip: ClipField = Field(None, description="Clip to use")
-    clip2: ClipField = Field(None, description="Clip2 to use")
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
-        }
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
-        if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
-        else:
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
-
-        original_size = (self.original_height, self.original_width)
-        crop_coords = (self.crop_top, self.crop_left)
-        target_size = (self.target_height, self.target_width)
-
-        add_time_ids = torch.tensor([original_size + crop_coords + target_size])
-
-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=torch.cat([c1, c2], dim=-1),
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec1,
-                )
-            ]
-        )
-
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
-
-        return CompelOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
-
-
-class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
-    """Parse prompt using compel package to conditioning."""
-
-    type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
-
-    style: str = Field(default="", description="Style prompt")  # TODO: ?
-    original_width: int = Field(1024, description="")
-    original_height: int = Field(1024, description="")
-    crop_top: int = Field(0, description="")
-    crop_left: int = Field(0, description="")
-    aesthetic_score: float = Field(6.0, description="")
-    clip2: ClipField = Field(None, description="Clip to use")
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "title": "SDXL Refiner Prompt (Raw)",
-                "tags": ["prompt", "compel"],
-                "type_hints": {"model": "model"},
-            },
-        }
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "")
+        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False)
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index b906719923..4ff8c5abc7 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -386,8 +386,7 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        unconditioning: torch.Tensor,
-        conditioning: torch.Tensor,
+        conditioning_data,
         **kwargs,
     ):
         # low-memory sequential path
@@ -444,8 +443,7 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        unconditioning,
-        conditioning,
+        conditioning_data,
         cross_attention_control_types_to_do,
         **kwargs,
    ):