diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 7c3ce7a819..d0b55cd185 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -37,6 +37,10 @@ class BasicConditioningInfo: # weight: float # mode: ConditioningAlgo + def to(self, device, dtype=None): + self.embeds = self.embeds.to(device=device, dtype=dtype) + return self + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): @@ -44,6 +48,11 @@ class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor add_time_ids: torch.Tensor + def to(self, device, dtype=None): + self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype) + self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype) + return super().to(device=device, dtype=dtype) + ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")] diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 6e2e0838bc..a63f98de24 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -174,11 +174,11 @@ class TextToLatentsInvocation(BaseInvocation): unet, ) -> ConditioningData: positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) - extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning + c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + extra_conditioning_info = c.extra_conditioning negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) + uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( unconditioned_embeddings=uc, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 624d47ff64..8a7616f1f1 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -212,8 +212,8 @@ class ControlNetData: @dataclass class ConditioningData: - unconditioned_embeddings: torch.Tensor - text_embeddings: torch.Tensor + unconditioned_embeddings: Any # TODO: type + text_embeddings: Any # TODO: type guidance_scale: Union[float, List[float]] """ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). @@ -392,48 +392,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): submodels.append(value) return submodels - def image_from_embeddings( - self, - latents: torch.Tensor, - num_inference_steps: int, - conditioning_data: ConditioningData, - *, - noise: torch.Tensor, - callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - r""" - Function invoked when calling the pipeline for generation. - - :param conditioning_data: - :param latents: Pre-generated un-noised latents, to be used as inputs for - image generation. Can be used to tweak the same generation with different prompts. - :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - :param noise: Noise to add to the latents, sampled from a Gaussian distribution. 
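# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the new `to()` helpers on the
# conditioning dataclasses are meant to be used. Class and field names mirror the
# compel.py hunk above; the tensor shapes and the float16 target are made-up
# values for demonstration only.
from dataclasses import dataclass

import torch


@dataclass
class _BasicConditioningInfo:
    embeds: torch.Tensor

    def to(self, device, dtype=None):
        # Move/cast every tensor the object owns and return self so the call chains.
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class _SDXLConditioningInfo(_BasicConditioningInfo):
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


# Callers (e.g. the latent.py hunk above) can now move the whole conditioning
# object in one call instead of reaching into `.embeds` directly:
cond = _SDXLConditioningInfo(
    embeds=torch.randn(1, 77, 2048),
    pooled_embeds=torch.randn(1, 1280),
    add_time_ids=torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]),
)
cond = cond.to(device="cpu", dtype=torch.float16)
# ---------------------------------------------------------------------------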
- :param callback: - :param run_id: - """ - result_latents, result_attention_map_saver = self.latents_from_embeddings( - latents, - num_inference_steps, - conditioning_data, - noise=noise, - run_id=run_id, - callback=callback, - ) - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput( - images=image, - nsfw_content_detected=[], - attention_map_saver=result_attention_map_saver, - ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - def latents_from_embeddings( self, latents: torch.Tensor, @@ -492,13 +450,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): extra_conditioning_info=extra_conditioning_info, step_count=len(self.scheduler.timesteps), ): - yield PipelineIntermediateState( - run_id=run_id, - step=-1, - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) - batch_size = latents.shape[0] batched_t = torch.full( (batch_size,), @@ -506,8 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): dtype=timesteps.dtype, device=self._model_group.device_for(self.unet), ) + #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers latents = self.scheduler.add_noise(latents, noise, batched_t) + yield PipelineIntermediateState( + run_id=run_id, + step=-1, + timestep=self.scheduler.config.num_train_timesteps, + latents=latents, + ) + attention_map_saver: Optional[AttentionMapSaver] = None # print("timesteps:", timesteps) for i, t in enumerate(self.progress_bar(timesteps)): @@ -569,95 +528,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # TODO: should this scaling happen here or inside self._unet_forward? # i.e. 
before or after passing it to InvokeAIDiffuserComponent - unet_latent_input = self.scheduler.scale_model_input(latents, timestep) + latent_model_input = self.scheduler.scale_model_input(latents, timestep) # default is no controlnet, so set controlnet processing output to None - down_block_res_samples, mid_block_res_sample = None, None - + controlnet_down_block_samples, controlnet_mid_block_sample = None, None if control_data is not None: - # control_data should be type List[ControlNetData] - # this loop covers both ControlNet (one ControlNetData in list) - # and MultiControlNet (multiple ControlNetData in list) - for i, control_datum in enumerate(control_data): - control_mode = control_datum.control_mode - # soft_injection and cfg_injection are the two ControlNet control_mode booleans - # that are combined at higher level to make control_mode enum - # soft_injection determines whether to do per-layer re-weighting adjustment (if True) - # or default weighting (if False) - soft_injection = control_mode == "more_prompt" or control_mode == "more_control" - # cfg_injection = determines whether to apply ControlNet to only the conditional (if True) - # or the default both conditional and unconditional (if False) - cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" + controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step( + control_data=control_data, + sample=latent_model_input, + timestep=timestep, + step_index=step_index, + total_step_count=total_step_count, + conditioning_data=conditioning_data, + ) - first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) - last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) - # only apply controlnet if current step is within the controlnet's begin/end step range - if step_index >= first_control_step and step_index <= last_control_step: - if cfg_injection: - control_latent_input = unet_latent_input - else: - # expand the latents input to control model if doing classifier free guidance - # (which I think for now is always true, there is conditional elsewhere that stops execution if - # classifier_free_guidance is <= 1.0 ?) - control_latent_input = torch.cat([unet_latent_input] * 2) - - if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned - encoder_hidden_states = conditioning_data.text_embeddings - encoder_attention_mask = None - else: - ( - encoder_hidden_states, - encoder_attention_mask, - ) = self.invokeai_diffuser._concat_conditionings_for_batch( - conditioning_data.unconditioned_embeddings, - conditioning_data.text_embeddings, - ) - if isinstance(control_datum.weight, list): - # if controlnet has multiple weights, use the weight for the current step - controlnet_weight = control_datum.weight[step_index] - else: - # if controlnet has a single weight, use it for all steps - controlnet_weight = control_datum.weight - - # controlnet(s) inference - down_samples, mid_sample = control_datum.model( - sample=control_latent_input, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=control_datum.image_tensor, - conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale - encoder_attention_mask=encoder_attention_mask, - guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel - return_dict=False, - ) - if cfg_injection: - # Inferred ControlNet only for the conditional batch. 
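# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the step loop in this hunk now asks
# the diffuser component for separate unconditioned/conditioned noise predictions
# (`do_unet_step`) and applies classifier-free guidance in the pipeline via
# `_combine`. Assuming `_combine` follows the standard CFG formula, the math it
# performs is equivalent to:
import torch


def combine_cfg(uc_noise_pred: torch.Tensor, c_noise_pred: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Start from the unconditioned prediction and push it toward the conditioned
    # one, scaled by the guidance factor.
    return uc_noise_pred + guidance_scale * (c_noise_pred - uc_noise_pred)


# Per-step guidance scales remain supported: when `guidance_scale` is a list, the
# pipeline indexes it with `step_index` before combining.
uc = torch.zeros(1, 4, 64, 64)
c = torch.ones(1, 4, 64, 64)
noise_pred = combine_cfg(uc, c, guidance_scale=7.5)
# ---------------------------------------------------------------------------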
- # To apply the output of ControlNet to both the unconditional and conditional batches, - # prepend zeros for unconditional batch - down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] - mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) - - if down_block_res_samples is None and mid_block_res_sample is None: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - # add controlnet outputs together if have multiple controlnets - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) - ] - mid_block_res_sample += mid_sample - - # predict the noise residual - noise_pred = self.invokeai_diffuser.do_diffusion_step( - x=unet_latent_input, - sigma=t, - unconditioning=conditioning_data.unconditioned_embeddings, - conditioning=conditioning_data.text_embeddings, - unconditional_guidance_scale=conditioning_data.guidance_scale, + uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( + sample=latent_model_input, + timestep=t, # TODO: debug how handled batched and non batched timesteps step_index=step_index, total_step_count=total_step_count, - down_block_additional_residuals=down_block_res_samples, # from controlnet(s) - mid_block_additional_residual=mid_block_res_sample, # from controlnet(s) + conditioning_data=conditioning_data, + + # extra: + down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s) + mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s) + ) + + guidance_scale = conditioning_data.guidance_scale + if isinstance(guidance_scale, list): + guidance_scale = guidance_scale[step_index] + + noise_pred = self.invokeai_diffuser._combine( + uc_noise_pred, + c_noise_pred, + guidance_scale, ) # compute the previous noisy sample x_t -> x_t-1 @@ -738,41 +642,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): callback, ) - def img2img_from_latents_and_embeddings( - self, - initial_latents, - num_inference_steps, - conditioning_data: ConditioningData, - strength, - noise: torch.Tensor, - run_id=None, - callback=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength) - result_latents, result_attention_maps = self.latents_from_embeddings( - latents=initial_latents - if strength < 1.0 - else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype), - num_inference_steps=num_inference_steps, - conditioning_data=conditioning_data, - timesteps=timesteps, - noise=noise, - run_id=run_id, - callback=callback, - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput( - images=image, - nsfw_content_detected=[], - attention_map_saver=result_attention_maps, - ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int): img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) assert img2img_pipeline.scheduler is self.scheduler @@ -877,7 +746,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): nsfw_content_detected=[], attention_map_saver=result_attention_maps, ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) + return 
self.check_for_safety(output, dtype=self.unet.dtype) def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): init_image = init_image.to(device=device, dtype=dtype) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index c01cf82c57..b906719923 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from dataclasses import dataclass -from math import ceil +import math from typing import Any, Callable, Dict, Optional, Union, List import numpy as np @@ -127,33 +127,119 @@ class InvokeAIDiffuserComponent: for _, module in tokens_cross_attention_modules: module.set_attention_slice_calculated_callback(None) - def do_diffusion_step( + def do_controlnet_step( self, - x: torch.Tensor, - sigma: torch.Tensor, - unconditioning: Union[torch.Tensor, dict], - conditioning: Union[torch.Tensor, dict], - # unconditional_guidance_scale: float, - unconditional_guidance_scale: Union[float, List[float]], - step_index: Optional[int] = None, - total_step_count: Optional[int] = None, + control_data, + sample: torch.Tensor, + timestep: torch.Tensor, + step_index: int, + total_step_count: int, + conditioning_data, + ): + down_block_res_samples, mid_block_res_sample = None, None + + # control_data should be type List[ControlNetData] + # this loop covers both ControlNet (one ControlNetData in list) + # and MultiControlNet (multiple ControlNetData in list) + for i, control_datum in enumerate(control_data): + control_mode = control_datum.control_mode + # soft_injection and cfg_injection are the two ControlNet control_mode booleans + # that are combined at higher level to make control_mode enum + # soft_injection determines whether to do per-layer re-weighting adjustment (if True) + # or default weighting (if False) + soft_injection = control_mode == "more_prompt" or control_mode == "more_control" + # cfg_injection = determines whether to apply ControlNet to only the conditional (if True) + # or the default both conditional and unconditional (if False) + cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" + + first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) + last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) + # only apply controlnet if current step is within the controlnet's begin/end step range + if step_index >= first_control_step and step_index <= last_control_step: + if cfg_injection: + sample_model_input = sample + else: + # expand the latents input to control model if doing classifier free guidance + # (which I think for now is always true, there is conditional elsewhere that stops execution if + # classifier_free_guidance is <= 1.0 ?) 
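# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the begin/end step percentages
# in `do_controlnet_step` gate a ControlNet. The floor/ceil and inclusive-range
# check mirror the code in this hunk; the numbers are made-up example values.
import math


def controlnet_active(step_index: int, total_step_count: int,
                      begin_step_percent: float, end_step_percent: float) -> bool:
    first_control_step = math.floor(begin_step_percent * total_step_count)
    last_control_step = math.ceil(end_step_percent * total_step_count)
    # Apply the ControlNet only while the current step falls inside its window.
    return first_control_step <= step_index <= last_control_step


# e.g. a ControlNet configured for the first half of a 30-step run:
assert controlnet_active(5, 30, begin_step_percent=0.0, end_step_percent=0.5)
assert not controlnet_active(20, 30, begin_step_percent=0.0, end_step_percent=0.5)
# ---------------------------------------------------------------------------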
+ sample_model_input = torch.cat([sample] * 2) + + added_cond_kwargs = None + + if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned + if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": + added_cond_kwargs = { + "text_embeds": conditioning_data.text_embeddings.pooled_embeds, + "time_ids": conditioning_data.text_embeddings.add_time_ids, + } + encoder_hidden_states = conditioning_data.text_embeddings.embeds + encoder_attention_mask = None + else: + if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": + added_cond_kwargs = { + "text_embeds": torch.cat([ + # TODO: how to pad? just by zeros? or even truncate? + conditioning_data.unconditioned_embeddings.pooled_embeds, + conditioning_data.text_embeddings.pooled_embeds, + ], dim=0), + "time_ids": torch.cat([ + conditioning_data.unconditioned_embeddings.add_time_ids, + conditioning_data.text_embeddings.add_time_ids, + ], dim=0), + } + ( + encoder_hidden_states, + encoder_attention_mask, + ) = self._concat_conditionings_for_batch( + conditioning_data.unconditioned_embeddings.embeds, + conditioning_data.text_embeddings.embeds, + ) + if isinstance(control_datum.weight, list): + # if controlnet has multiple weights, use the weight for the current step + controlnet_weight = control_datum.weight[step_index] + else: + # if controlnet has a single weight, use it for all steps + controlnet_weight = control_datum.weight + + # controlnet(s) inference + down_samples, mid_sample = control_datum.model( + sample=sample_model_input, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=control_datum.image_tensor, + conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale + encoder_attention_mask=encoder_attention_mask, + guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel + return_dict=False, + ) + if cfg_injection: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # prepend zeros for unconditional batch + down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] + mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) + + if down_block_res_samples is None and mid_block_res_sample is None: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + # add controlnet outputs together if have multiple controlnets + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def do_unet_step( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + conditioning_data, # TODO: type + step_index: int, + total_step_count: int, **kwargs, ): - """ - :param x: current latents - :param sigma: aka t, passed to the internal model to control how much denoising will occur - :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] - :param conditioning: embeddings for conditioned output. 
for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] - :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has - :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas . - :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. - """ - - if isinstance(unconditional_guidance_scale, list): - guidance_scale = unconditional_guidance_scale[step_index] - else: - guidance_scale = unconditional_guidance_scale - cross_attention_control_types_to_do = [] context: Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: @@ -163,25 +249,15 @@ class InvokeAIDiffuserComponent: ) wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 - wants_hybrid_conditioning = isinstance(conditioning, dict) - if wants_hybrid_conditioning: - unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( - x, - sigma, - unconditioning, - conditioning, - **kwargs, - ) - elif wants_cross_attention_control: + if wants_cross_attention_control: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_cross_attention_controlled_conditioning( - x, - sigma, - unconditioning, - conditioning, + sample, + timestep, + conditioning_data, cross_attention_control_types_to_do, **kwargs, ) @@ -190,10 +266,9 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - x, - sigma, - unconditioning, - conditioning, + sample, + timestep, + conditioning_data, **kwargs, ) @@ -202,21 +277,13 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - x, - sigma, - unconditioning, - conditioning, + sample, + timestep, + conditioning_data, **kwargs, ) - combined_next_x = self._combine( - # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale - unconditioned_next_x, - conditioned_next_x, - guidance_scale, - ) - - return combined_next_x + return unconditioned_next_x, conditioned_next_x def do_latent_postprocessing( self, @@ -281,17 +348,35 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): + def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs): # fast batched path x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning) + added_cond_kwargs = None + if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": + added_cond_kwargs = { + "text_embeds": torch.cat([ + # TODO: how to pad? just by zeros? or even truncate? 
+ conditioning_data.unconditioned_embeddings.pooled_embeds, + conditioning_data.text_embeddings.pooled_embeds, + ], dim=0), + "time_ids": torch.cat([ + conditioning_data.unconditioned_embeddings.add_time_ids, + conditioning_data.text_embeddings.add_time_ids, + ], dim=0), + } + + both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( + conditioning_data.unconditioned_embeddings.embeds, + conditioning_data.text_embeddings.embeds + ) both_results = self.model_forward_callback( x_twice, sigma_twice, both_conditionings, encoder_attention_mask=encoder_attention_mask, + added_cond_kwargs=added_cond_kwargs, **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) @@ -320,46 +405,41 @@ class InvokeAIDiffuserComponent: if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) + added_cond_kwargs = None + is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo" + if is_sdxl: + added_cond_kwargs = { + "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, + "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, + } + unconditioned_next_x = self.model_forward_callback( x, sigma, - unconditioning, + conditioning_data.unconditioned_embeddings.embeds, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, + added_cond_kwargs=added_cond_kwargs, **kwargs, ) + + if is_sdxl: + added_cond_kwargs = { + "text_embeds": conditioning_data.text_embeddings.pooled_embeds, + "time_ids": conditioning_data.text_embeddings.add_time_ids, + } + conditioned_next_x = self.model_forward_callback( x, sigma, - conditioning, + conditioning_data.text_embeddings.embeds, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, + added_cond_kwargs=added_cond_kwargs, **kwargs, ) return unconditioned_next_x, conditioned_next_x - # TODO: looks unused - def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): - assert isinstance(conditioning, dict) - assert isinstance(unconditioning, dict) - x_twice = torch.cat([x] * 2) - sigma_twice = torch.cat([sigma] * 2) - both_conditionings = dict() - for k in conditioning: - if isinstance(conditioning[k], list): - both_conditionings[k] = [ - torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k])) - ] - else: - both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) - unconditioned_next_x, conditioned_next_x = self.model_forward_callback( - x_twice, - sigma_twice, - both_conditionings, - **kwargs, - ).chunk(2) - return unconditioned_next_x, conditioned_next_x - def _apply_cross_attention_controlled_conditioning( self, x: torch.Tensor, @@ -391,26 +471,43 @@ class InvokeAIDiffuserComponent: mask=context.cross_attention_mask, cross_attention_types_to_do=[], ) + + added_cond_kwargs = None + is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo" + if is_sdxl: + added_cond_kwargs = { + "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, + "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, + } + # no cross attention for unconditioning (negative prompt) unconditioned_next_x = self.model_forward_callback( x, sigma, - unconditioning, + conditioning_data.unconditioned_embeddings.embeds, {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=uncond_down_block, 
mid_block_additional_residual=uncond_mid_block, + added_cond_kwargs=added_cond_kwargs, **kwargs, ) + if is_sdxl: + added_cond_kwargs = { + "text_embeds": conditioning_data.text_embeddings.pooled_embeds, + "time_ids": conditioning_data.text_embeddings.add_time_ids, + } + # do requested cross attention types for conditioning (positive prompt) cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do conditioned_next_x = self.model_forward_callback( x, sigma, - conditioning, + conditioning_data.text_embeddings.embeds, {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, + added_cond_kwargs=added_cond_kwargs, **kwargs, ) return unconditioned_next_x, conditioned_next_x @@ -564,7 +661,7 @@ class InvokeAIDiffuserComponent: # below is fugly omg conditionings = [uc] + [c for c, weight in weighted_cond_list] weights = [1] + [weight for c, weight in weighted_cond_list] - chunk_count = ceil(len(conditionings) / 2) + chunk_count = math.ceil(len(conditionings) / 2) deltas = None for chunk_index in range(chunk_count): offset = chunk_index * 2
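# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the SDXL-specific
# `added_cond_kwargs` in the hunks above are assembled for the batched
# classifier-free-guidance path. The unconditioned entries come first so the
# result lines up with the uncond-then-cond ordering used by
# `_concat_conditionings_for_batch` and the later `chunk(2)` unpack. Shapes are
# made-up example values; the open TODO about padding or truncating mismatched
# prompt embeddings is left as-is here.
import torch


def sdxl_added_cond_kwargs(uncond_pooled: torch.Tensor, cond_pooled: torch.Tensor,
                           uncond_time_ids: torch.Tensor, cond_time_ids: torch.Tensor) -> dict:
    return {
        "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),
        "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),
    }


kwargs = sdxl_added_cond_kwargs(
    uncond_pooled=torch.randn(1, 1280),
    cond_pooled=torch.randn(1, 1280),
    uncond_time_ids=torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]),
    cond_time_ids=torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]),
)
assert kwargs["text_embeds"].shape == (2, 1280) and kwargs["time_ids"].shape == (2, 6)
# ---------------------------------------------------------------------------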