This commit is contained in:
Sergey Borisov 2023-08-06 05:05:25 +03:00
parent dc96a3e79d
commit 9aaf67c5b4
4 changed files with 232 additions and 257 deletions

View File

@ -37,6 +37,10 @@ class BasicConditioningInfo:
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
@ -44,6 +48,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]

View File

@ -174,11 +174,11 @@ class TextToLatentsInvocation(BaseInvocation):
unet,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,

View File

@ -212,8 +212,8 @@ class ControlNetData:
@dataclass
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
unconditioned_embeddings: Any # TODO: type
text_embeddings: Any # TODO: type
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@ -392,48 +392,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
submodels.append(value)
return submodels
def image_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
conditioning_data: ConditioningData,
*,
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
:param conditioning_data:
:param latents: Pre-generated un-noised latents, to be used as inputs for
image generation. Can be used to tweak the same generation with different prompts.
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
:param callback:
:param run_id:
"""
result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents,
num_inference_steps,
conditioning_data,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_map_saver,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def latents_from_embeddings(
self,
latents: torch.Tensor,
@ -492,13 +450,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(
run_id=run_id,
step=-1,
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
batch_size = latents.shape[0]
batched_t = torch.full(
(batch_size,),
@ -506,8 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
dtype=timesteps.dtype,
device=self._model_group.device_for(self.unet),
)
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
yield PipelineIntermediateState(
run_id=run_id,
step=-1,
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
@ -569,95 +528,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
if control_data is not None:
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data,
sample=latent_model_input,
timestep=timestep,
step_index=step_index,
total_step_count=total_step_count,
conditioning_data=conditioning_data,
)
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else:
(
encoder_hidden_states,
encoder_attention_mask,
) = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=control_latent_input,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
x=unet_latent_input,
sigma=t,
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t, # TODO: debug how handled batched and non batched timesteps
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
)
guidance_scale = conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[step_index]
noise_pred = self.invokeai_diffuser._combine(
uc_noise_pred,
c_noise_pred,
guidance_scale,
)
# compute the previous noisy sample x_t -> x_t-1
@ -738,41 +642,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback,
)
def img2img_from_latents_and_embeddings(
self,
initial_latents,
num_inference_steps,
conditioning_data: ConditioningData,
strength,
noise: torch.Tensor,
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents
if strength < 1.0
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
timesteps=timesteps,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler
@ -877,7 +746,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
return self.check_for_safety(output, dtype=self.unet.dtype)
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
init_image = init_image.to(device=device, dtype=dtype)

View File

@ -1,6 +1,6 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
import math
from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np
@ -127,33 +127,119 @@ class InvokeAIDiffuserComponent:
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
def do_diffusion_step(
def do_controlnet_step(
self,
x: torch.Tensor,
sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict],
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
control_data,
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
conditioning_data,
):
down_block_res_samples, mid_block_res_sample = None, None
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
sample_model_input = sample
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
sample_model_input = torch.cat([sample] * 2)
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}
encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_attention_mask = None
else:
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
], dim=0),
"time_ids": torch.cat([
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
], dim=0),
}
(
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=sample_model_input,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
def do_unet_step(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data, # TODO: type
step_index: int,
total_step_count: int,
**kwargs,
):
"""
:param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
if isinstance(unconditional_guidance_scale, list):
guidance_scale = unconditional_guidance_scale[step_index]
else:
guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@ -163,25 +249,15 @@ class InvokeAIDiffuserComponent:
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
elif wants_cross_attention_control:
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
x,
sigma,
unconditioning,
conditioning,
sample,
timestep,
conditioning_data,
cross_attention_control_types_to_do,
**kwargs,
)
@ -190,10 +266,9 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x,
sigma,
unconditioning,
conditioning,
sample,
timestep,
conditioning_data,
**kwargs,
)
@ -202,21 +277,13 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x,
sigma,
unconditioning,
conditioning,
sample,
timestep,
conditioning_data,
**kwargs,
)
combined_next_x = self._combine(
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x,
conditioned_next_x,
guidance_scale,
)
return combined_next_x
return unconditioned_next_x, conditioned_next_x
def do_latent_postprocessing(
self,
@ -281,17 +348,35 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
added_cond_kwargs = None
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
], dim=0),
"time_ids": torch.cat([
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
], dim=0),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds
)
both_results = self.model_forward_callback(
x_twice,
sigma_twice,
both_conditionings,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
@ -320,46 +405,41 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
unconditioning,
conditioning_data.unconditioned_embeddings.embeds,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning,
conditioning_data.text_embeddings.embeds,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
# TODO: looks unused
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = dict()
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice,
sigma_twice,
both_conditionings,
**kwargs,
).chunk(2)
return unconditioned_next_x, conditioned_next_x
def _apply_cross_attention_controlled_conditioning(
self,
x: torch.Tensor,
@ -391,26 +471,43 @@ class InvokeAIDiffuserComponent:
mask=context.cross_attention_mask,
cross_attention_types_to_do=[],
)
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
}
# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
unconditioning,
conditioning_data.unconditioned_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
}
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning,
conditioning_data.text_embeddings.embeds,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
@ -564,7 +661,7 @@ class InvokeAIDiffuserComponent:
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = ceil(len(conditionings) / 2)
chunk_count = math.ceil(len(conditionings) / 2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index * 2