commit 9aaf67c5b4
parent dc96a3e79d
Author: Sergey Borisov
Date:   2023-08-06 05:05:25 +03:00

4 changed files with 232 additions and 257 deletions

View File

@@ -37,6 +37,10 @@ class BasicConditioningInfo:
     # weight: float
     # mode: ConditioningAlgo
 
+    def to(self, device, dtype=None):
+        self.embeds = self.embeds.to(device=device, dtype=dtype)
+        return self
+
 
 @dataclass
 class SDXLConditioningInfo(BasicConditioningInfo):
@@ -44,6 +48,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
     pooled_embeds: torch.Tensor
     add_time_ids: torch.Tensor
 
+    def to(self, device, dtype=None):
+        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
+        return super().to(device=device, dtype=dtype)
+
 
 ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
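Note: the to() methods added above cast the tensor fields in place and return self (the SDXL subclass delegates to its parent via super().to(...)), so call sites can keep using the returned object's other fields. A minimal standalone sketch of that pattern, with illustrative names and shapes that are not part of this commit:

from dataclasses import dataclass

import torch


@dataclass
class BaseInfo:
    embeds: torch.Tensor

    def to(self, device, dtype=None):
        # Cast tensor fields in place and return self so calls can be chained.
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class PooledInfo(BaseInfo):
    pooled_embeds: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


info = PooledInfo(embeds=torch.zeros(1, 77, 768), pooled_embeds=torch.zeros(1, 1280))
info = info.to("cpu", dtype=torch.float16)  # same object back, tensors now float16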

View File

@@ -174,11 +174,11 @@ class TextToLatentsInvocation(BaseInvocation):
         unet,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
-        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
+        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        extra_conditioning_info = c.extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
+        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,

View File

@@ -212,8 +212,8 @@ class ControlNetData:
 
 @dataclass
 class ConditioningData:
-    unconditioned_embeddings: torch.Tensor
-    text_embeddings: torch.Tensor
+    unconditioned_embeddings: Any # TODO: type
+    text_embeddings: Any # TODO: type
     guidance_scale: Union[float, List[float]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -392,48 +392,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 submodels.append(value)
         return submodels
 
-    def image_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        noise: torch.Tensor,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        r"""
-        Function invoked when calling the pipeline for generation.
-        :param conditioning_data:
-        :param latents: Pre-generated un-noised latents, to be used as inputs for
-            image generation. Can be used to tweak the same generation with different prompts.
-        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
-            image at the expense of slower inference.
-        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
-        :param callback:
-        :param run_id:
-        """
-        result_latents, result_attention_map_saver = self.latents_from_embeddings(
-            latents,
-            num_inference_steps,
-            conditioning_data,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_map_saver,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
@@ -492,13 +450,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             extra_conditioning_info=extra_conditioning_info,
             step_count=len(self.scheduler.timesteps),
         ):
-            yield PipelineIntermediateState(
-                run_id=run_id,
-                step=-1,
-                timestep=self.scheduler.config.num_train_timesteps,
-                latents=latents,
-            )
-
             batch_size = latents.shape[0]
             batched_t = torch.full(
                 (batch_size,),
@@ -506,8 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 dtype=timesteps.dtype,
                 device=self._model_group.device_for(self.unet),
             )
+            #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_t)
+            yield PipelineIntermediateState(
+                run_id=run_id,
+                step=-1,
+                timestep=self.scheduler.config.num_train_timesteps,
+                latents=latents,
+            )
+
             attention_map_saver: Optional[AttentionMapSaver] = None
             # print("timesteps:", timesteps)
             for i, t in enumerate(self.progress_bar(timesteps)):
@@ -569,95 +528,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # TODO: should this scaling happen here or inside self._unet_forward?
         # i.e. before or after passing it to InvokeAIDiffuserComponent
-        unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
+        latent_model_input = self.scheduler.scale_model_input(latents, timestep)
 
         # default is no controlnet, so set controlnet processing output to None
-        down_block_res_samples, mid_block_res_sample = None, None
+        controlnet_down_block_samples, controlnet_mid_block_sample = None, None
         if control_data is not None:
-            # control_data should be type List[ControlNetData]
-            # this loop covers both ControlNet (one ControlNetData in list)
-            # and MultiControlNet (multiple ControlNetData in list)
-            for i, control_datum in enumerate(control_data):
-                control_mode = control_datum.control_mode
-                # soft_injection and cfg_injection are the two ControlNet control_mode booleans
-                # that are combined at higher level to make control_mode enum
-                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
-                # or default weighting (if False)
-                soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
-                # cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
-                # or the default both conditional and unconditional (if False)
-                cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
-
-                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
-                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
-                # only apply controlnet if current step is within the controlnet's begin/end step range
-                if step_index >= first_control_step and step_index <= last_control_step:
-                    if cfg_injection:
-                        control_latent_input = unet_latent_input
-                    else:
-                        # expand the latents input to control model if doing classifier free guidance
-                        # (which I think for now is always true, there is conditional elsewhere that stops execution if
-                        # classifier_free_guidance is <= 1.0 ?)
-                        control_latent_input = torch.cat([unet_latent_input] * 2)
-
-                    if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
-                        encoder_hidden_states = conditioning_data.text_embeddings
-                        encoder_attention_mask = None
-                    else:
-                        (
-                            encoder_hidden_states,
-                            encoder_attention_mask,
-                        ) = self.invokeai_diffuser._concat_conditionings_for_batch(
-                            conditioning_data.unconditioned_embeddings,
-                            conditioning_data.text_embeddings,
-                        )
-                    if isinstance(control_datum.weight, list):
-                        # if controlnet has multiple weights, use the weight for the current step
-                        controlnet_weight = control_datum.weight[step_index]
-                    else:
-                        # if controlnet has a single weight, use it for all steps
-                        controlnet_weight = control_datum.weight
-
-                    # controlnet(s) inference
-                    down_samples, mid_sample = control_datum.model(
-                        sample=control_latent_input,
-                        timestep=timestep,
-                        encoder_hidden_states=encoder_hidden_states,
-                        controlnet_cond=control_datum.image_tensor,
-                        conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
-                        encoder_attention_mask=encoder_attention_mask,
-                        guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
-                        return_dict=False,
-                    )
-                    if cfg_injection:
-                        # Inferred ControlNet only for the conditional batch.
-                        # To apply the output of ControlNet to both the unconditional and conditional batches,
-                        # prepend zeros for unconditional batch
-                        down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
-                        mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
-
-                    if down_block_res_samples is None and mid_block_res_sample is None:
-                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
-                    else:
-                        # add controlnet outputs together if have multiple controlnets
-                        down_block_res_samples = [
-                            samples_prev + samples_curr
-                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
-                        ]
-                        mid_block_res_sample += mid_sample
-
-        # predict the noise residual
-        noise_pred = self.invokeai_diffuser.do_diffusion_step(
-            x=unet_latent_input,
-            sigma=t,
-            unconditioning=conditioning_data.unconditioned_embeddings,
-            conditioning=conditioning_data.text_embeddings,
-            unconditional_guidance_scale=conditioning_data.guidance_scale,
+            controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
+                control_data=control_data,
+                sample=latent_model_input,
+                timestep=timestep,
+                step_index=step_index,
+                total_step_count=total_step_count,
+                conditioning_data=conditioning_data,
+            )
+
+        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
+            sample=latent_model_input,
+            timestep=t, # TODO: debug how handled batched and non batched timesteps
             step_index=step_index,
             total_step_count=total_step_count,
-            down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
-            mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
+            conditioning_data=conditioning_data,
+            # extra:
+            down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
+            mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
+        )
+
+        guidance_scale = conditioning_data.guidance_scale
+        if isinstance(guidance_scale, list):
+            guidance_scale = guidance_scale[step_index]
+
+        noise_pred = self.invokeai_diffuser._combine(
+            uc_noise_pred,
+            c_noise_pred,
+            guidance_scale,
         )
 
         # compute the previous noisy sample x_t -> x_t-1
@@ -738,41 +642,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             callback,
         )
 
-    def img2img_from_latents_and_embeddings(
-        self,
-        initial_latents,
-        num_inference_steps,
-        conditioning_data: ConditioningData,
-        strength,
-        noise: torch.Tensor,
-        run_id=None,
-        callback=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
-        result_latents, result_attention_maps = self.latents_from_embeddings(
-            latents=initial_latents
-            if strength < 1.0
-            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
-            num_inference_steps=num_inference_steps,
-            conditioning_data=conditioning_data,
-            timesteps=timesteps,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_maps,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
@@ -877,7 +746,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 nsfw_content_detected=[],
                 attention_map_saver=result_attention_maps,
             )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
+            return self.check_for_safety(output, dtype=self.unet.dtype)
 
     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
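Note on the reworked step above: do_unet_step() now returns the unconditioned and conditioned noise predictions separately, and the pipeline applies classifier-free guidance itself via invokeai_diffuser._combine(), indexing guidance_scale per step when a list is supplied. As a hedged sketch, the combination presumably follows the standard CFG formula (the _combine implementation itself is not shown in this diff):

import torch


def combine_cfg(uc_noise_pred: torch.Tensor, c_noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction from the unconditioned
    # result toward the conditioned one, scaled by guidance_scale.
    return uc_noise_pred + guidance_scale * (c_noise_pred - uc_noise_pred)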

View File

@@ -1,6 +1,6 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
-from math import ceil
+import math
 from typing import Any, Callable, Dict, Optional, Union, List
 
 import numpy as np
@@ -127,33 +127,119 @@ class InvokeAIDiffuserComponent:
         for _, module in tokens_cross_attention_modules:
             module.set_attention_slice_calculated_callback(None)
 
-    def do_diffusion_step(
+    def do_controlnet_step(
         self,
-        x: torch.Tensor,
-        sigma: torch.Tensor,
-        unconditioning: Union[torch.Tensor, dict],
-        conditioning: Union[torch.Tensor, dict],
-        # unconditional_guidance_scale: float,
-        unconditional_guidance_scale: Union[float, List[float]],
-        step_index: Optional[int] = None,
-        total_step_count: Optional[int] = None,
+        control_data,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_step_count: int,
+        conditioning_data,
+    ):
+        down_block_res_samples, mid_block_res_sample = None, None
+
+        # control_data should be type List[ControlNetData]
+        # this loop covers both ControlNet (one ControlNetData in list)
+        # and MultiControlNet (multiple ControlNetData in list)
+        for i, control_datum in enumerate(control_data):
+            control_mode = control_datum.control_mode
+            # soft_injection and cfg_injection are the two ControlNet control_mode booleans
+            # that are combined at higher level to make control_mode enum
+            # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
+            # or default weighting (if False)
+            soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
+            # cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
+            # or the default both conditional and unconditional (if False)
+            cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
+
+            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+            # only apply controlnet if current step is within the controlnet's begin/end step range
+            if step_index >= first_control_step and step_index <= last_control_step:
+                if cfg_injection:
+                    sample_model_input = sample
+                else:
+                    # expand the latents input to control model if doing classifier free guidance
+                    # (which I think for now is always true, there is conditional elsewhere that stops execution if
+                    # classifier_free_guidance is <= 1.0 ?)
+                    sample_model_input = torch.cat([sample] * 2)
+
+                added_cond_kwargs = None
+
+                if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
+                    if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+                        added_cond_kwargs = {
+                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
+                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                        }
+                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                    encoder_attention_mask = None
+                else:
+                    if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+                        added_cond_kwargs = {
+                            "text_embeds": torch.cat([
+                                # TODO: how to pad? just by zeros? or even truncate?
+                                conditioning_data.unconditioned_embeddings.pooled_embeds,
+                                conditioning_data.text_embeddings.pooled_embeds,
+                            ], dim=0),
+                            "time_ids": torch.cat([
+                                conditioning_data.unconditioned_embeddings.add_time_ids,
+                                conditioning_data.text_embeddings.add_time_ids,
+                            ], dim=0),
+                        }
+                    (
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                    ) = self._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings.embeds,
+                        conditioning_data.text_embeddings.embeds,
+                    )
+                if isinstance(control_datum.weight, list):
+                    # if controlnet has multiple weights, use the weight for the current step
+                    controlnet_weight = control_datum.weight[step_index]
+                else:
+                    # if controlnet has a single weight, use it for all steps
+                    controlnet_weight = control_datum.weight
+
+                # controlnet(s) inference
+                down_samples, mid_sample = control_datum.model(
+                    sample=sample_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=encoder_hidden_states,
+                    controlnet_cond=control_datum.image_tensor,
+                    conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
+                    encoder_attention_mask=encoder_attention_mask,
+                    guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
+                    return_dict=False,
+                )
+                if cfg_injection:
+                    # Inferred ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # prepend zeros for unconditional batch
+                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+                if down_block_res_samples is None and mid_block_res_sample is None:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    # add controlnet outputs together if have multiple controlnets
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+    def do_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        conditioning_data, # TODO: type
+        step_index: int,
+        total_step_count: int,
         **kwargs,
     ):
-        """
-        :param x: current latents
-        :param sigma: aka t, passed to the internal model to control how much denoising will occur
-        :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
-        :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
-        :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
-        :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
-        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
-        """
-        if isinstance(unconditional_guidance_scale, list):
-            guidance_scale = unconditional_guidance_scale[step_index]
-        else:
-            guidance_scale = unconditional_guidance_scale
-
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
@@ -163,25 +249,15 @@ class InvokeAIDiffuserComponent:
             )
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
 
-        wants_hybrid_conditioning = isinstance(conditioning, dict)
-
-        if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
-                **kwargs,
-            )
-        elif wants_cross_attention_control:
+        if wants_cross_attention_control:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_cross_attention_controlled_conditioning(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
+                sample,
+                timestep,
+                conditioning_data,
                 cross_attention_control_types_to_do,
                 **kwargs,
             )
@@ -190,10 +266,9 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
+                sample,
+                timestep,
+                conditioning_data,
                 **kwargs,
             )
@@ -202,21 +277,13 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
+                sample,
+                timestep,
+                conditioning_data,
                 **kwargs,
             )
 
-        combined_next_x = self._combine(
-            # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
-            unconditioned_next_x,
-            conditioned_next_x,
-            guidance_scale,
-        )
-
-        return combined_next_x
+        return unconditioned_next_x, conditioned_next_x
 
     def do_latent_postprocessing(
         self,
@@ -281,17 +348,35 @@ class InvokeAIDiffuserComponent:
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
+    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
-        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
+
+        added_cond_kwargs = None
+        if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+            added_cond_kwargs = {
+                "text_embeds": torch.cat([
+                    # TODO: how to pad? just by zeros? or even truncate?
+                    conditioning_data.unconditioned_embeddings.pooled_embeds,
+                    conditioning_data.text_embeddings.pooled_embeds,
+                ], dim=0),
+                "time_ids": torch.cat([
+                    conditioning_data.unconditioned_embeddings.add_time_ids,
+                    conditioning_data.text_embeddings.add_time_ids,
+                ], dim=0),
+            }
+
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.text_embeddings.embeds
+        )
         both_results = self.model_forward_callback(
             x_twice,
             sigma_twice,
             both_conditionings,
             encoder_attention_mask=encoder_attention_mask,
+            added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
@@ -320,46 +405,41 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
+        added_cond_kwargs = None
+        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        if is_sdxl:
+            added_cond_kwargs = {
+                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+            }
+
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            unconditioning,
+            conditioning_data.unconditioned_embeddings.embeds,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
+            added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
+
+        if is_sdxl:
+            added_cond_kwargs = {
+                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+            }
+
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning,
+            conditioning_data.text_embeddings.embeds,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
+            added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
 
-    # TODO: looks unused
-    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
-        assert isinstance(conditioning, dict)
-        assert isinstance(unconditioning, dict)
-        x_twice = torch.cat([x] * 2)
-        sigma_twice = torch.cat([sigma] * 2)
-        both_conditionings = dict()
-        for k in conditioning:
-            if isinstance(conditioning[k], list):
-                both_conditionings[k] = [
-                    torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
-                ]
-            else:
-                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
-        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
-            x_twice,
-            sigma_twice,
-            both_conditionings,
-            **kwargs,
-        ).chunk(2)
-        return unconditioned_next_x, conditioned_next_x
-
     def _apply_cross_attention_controlled_conditioning(
         self,
         x: torch.Tensor,
@@ -391,26 +471,43 @@ class InvokeAIDiffuserComponent:
             mask=context.cross_attention_mask,
             cross_attention_types_to_do=[],
         )
+
+        added_cond_kwargs = None
+        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        if is_sdxl:
+            added_cond_kwargs = {
+                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+            }
+
         # no cross attention for unconditioning (negative prompt)
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            unconditioning,
+            conditioning_data.unconditioned_embeddings.embeds,
             {"swap_cross_attn_context": cross_attn_processor_context},
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
+            added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
+
+        if is_sdxl:
+            added_cond_kwargs = {
+                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+            }
+
         # do requested cross attention types for conditioning (positive prompt)
         cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning,
+            conditioning_data.text_embeddings.embeds,
             {"swap_cross_attn_context": cross_attn_processor_context},
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
+            added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
@@ -564,7 +661,7 @@ class InvokeAIDiffuserComponent:
         # below is fugly omg
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
         weights = [1] + [weight for c, weight in weighted_cond_list]
-        chunk_count = ceil(len(conditionings) / 2)
+        chunk_count = math.ceil(len(conditionings) / 2)
         deltas = None
         for chunk_index in range(chunk_count):
             offset = chunk_index * 2
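The recurring SDXL branches above build added_cond_kwargs by concatenating the negative and positive pooled embeddings and time ids along the batch dimension, in the same unconditioned-then-conditioned order used for the token embeddings. A minimal sketch of that batching; the tensor shapes here are illustrative assumptions, not values from this commit:

import torch

# Illustrative shapes: batch of 1, pooled-embedding width 1280, 6 time-id values.
uncond_pooled = torch.zeros(1, 1280)
cond_pooled = torch.ones(1, 1280)
uncond_time_ids = torch.zeros(1, 6)
cond_time_ids = torch.ones(1, 6)

added_cond_kwargs = {
    # Unconditioned first, conditioned second, matching the embedding batch order.
    "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),
    "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),
}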