wip

This commit is contained in: parent dc96a3e79d, commit 9aaf67c5b4
@@ -37,6 +37,10 @@ class BasicConditioningInfo:
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
@@ -44,6 +48,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
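The new to() methods let a conditioning object be moved to the UNet's device and dtype in one call instead of callers reaching into .embeds directly; SDXLConditioningInfo.to() moves its extra tensors and then defers to the base class. A self-contained sketch of the pattern, using stand-in dataclasses rather than the real InvokeAI classes:

from dataclasses import dataclass

import torch


@dataclass
class BasicInfo:  # stand-in for BasicConditioningInfo
    embeds: torch.Tensor

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class SDXLInfo(BasicInfo):  # stand-in for SDXLConditioningInfo
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


cond = SDXLInfo(
    embeds=torch.randn(2, 77, 2048),
    pooled_embeds=torch.randn(2, 1280),
    add_time_ids=torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]] * 2),
)
cond = cond.to(device="cpu", dtype=torch.float16)  # embeds, pooled_embeds and add_time_ids all move together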
@@ -174,11 +174,11 @@ class TextToLatentsInvocation(BaseInvocation):
        unet,
    ) -> ConditioningData:
        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
        extra_conditioning_info = c.extra_conditioning

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

        conditioning_data = ConditioningData(
            unconditioned_embeddings=uc,
@@ -212,8 +212,8 @@ class ControlNetData:

@dataclass
class ConditioningData:
    unconditioned_embeddings: torch.Tensor
    text_embeddings: torch.Tensor
    unconditioned_embeddings: Any  # TODO: type
    text_embeddings: Any  # TODO: type
    guidance_scale: Union[float, List[float]]
    """
    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
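With this change the two embeddings fields carry conditioning-info objects (BasicConditioningInfo or SDXLConditioningInfo) instead of bare tensors, which is why they are temporarily typed Any. A rough, self-contained sketch of how such a container gets filled, using a stand-in dataclass with only the fields visible in this hunk:

from dataclasses import dataclass
from typing import Any, List, Union


@dataclass
class ConditioningDataSketch:  # stand-in mirroring the fields shown above
    unconditioned_embeddings: Any  # a conditioning-info object, not a raw tensor
    text_embeddings: Any
    guidance_scale: Union[float, List[float]]


uc = object()  # placeholder for the negative-prompt conditioning info
c = object()   # placeholder for the positive-prompt conditioning info

fixed = ConditioningDataSketch(unconditioned_embeddings=uc, text_embeddings=c, guidance_scale=7.5)
scheduled = ConditioningDataSketch(uc, c, guidance_scale=[7.5] * 20)  # a per-step CFG schedule is also allowed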
@@ -392,48 +392,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                submodels.append(value)
        return submodels

    def image_from_embeddings(
        self,
        latents: torch.Tensor,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.

        :param conditioning_data:
        :param latents: Pre-generated un-noised latents, to be used as inputs for
            image generation. Can be used to tweak the same generation with different prompts.
        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
            image at the expense of slower inference.
        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
        :param callback:
        :param run_id:
        """
        result_latents, result_attention_map_saver = self.latents_from_embeddings(
            latents,
            num_inference_steps,
            conditioning_data,
            noise=noise,
            run_id=run_id,
            callback=callback,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
        output = InvokeAIStableDiffusionPipelineOutput(
            images=image,
            nsfw_content_detected=[],
            attention_map_saver=result_attention_map_saver,
        )
        return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def latents_from_embeddings(
        self,
        latents: torch.Tensor,
@@ -492,13 +450,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            extra_conditioning_info=extra_conditioning_info,
            step_count=len(self.scheduler.timesteps),
        ):
            yield PipelineIntermediateState(
                run_id=run_id,
                step=-1,
                timestep=self.scheduler.config.num_train_timesteps,
                latents=latents,
            )

            batch_size = latents.shape[0]
            batched_t = torch.full(
                (batch_size,),
@@ -506,8 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                dtype=timesteps.dtype,
                device=self._model_group.device_for(self.unet),
            )
            #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            latents = self.scheduler.add_noise(latents, noise, batched_t)

            yield PipelineIntermediateState(
                run_id=run_id,
                step=-1,
                timestep=self.scheduler.config.num_train_timesteps,
                latents=latents,
            )

            attention_map_saver: Optional[AttentionMapSaver] = None
            # print("timesteps:", timesteps)
            for i, t in enumerate(self.progress_bar(timesteps)):
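The pipeline now noises the incoming latents with scheduler.add_noise(latents, noise, batched_t) rather than scaling pure noise by init_noise_sigma, and only then emits the step -1 intermediate state. A standalone illustration of what add_noise does, using a stock diffusers scheduler (not InvokeAI code):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(1, 4, 64, 64)  # clean (e.g. img2img init) latents
noise = torch.randn_like(latents)
batched_t = torch.full((latents.shape[0],), 999, dtype=torch.long)

# add_noise blends the clean latents with Gaussian noise according to the noise
# schedule at timestep t: sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise
noisy_latents = scheduler.add_noise(latents, noise, batched_t)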
@@ -569,95 +528,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        # TODO: should this scaling happen here or inside self._unet_forward?
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # default is no controlnet, so set controlnet processing output to None
        down_block_res_samples, mid_block_res_sample = None, None

        controlnet_down_block_samples, controlnet_mid_block_sample = None, None
        if control_data is not None:
            # control_data should be type List[ControlNetData]
            # this loop covers both ControlNet (one ControlNetData in list)
            # and MultiControlNet (multiple ControlNetData in list)
            for i, control_datum in enumerate(control_data):
                control_mode = control_datum.control_mode
                # soft_injection and cfg_injection are the two ControlNet control_mode booleans
                # that are combined at higher level to make control_mode enum
                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
                # or default weighting (if False)
                soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
                # cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
                # or the default both conditional and unconditional (if False)
                cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
            controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
                control_data=control_data,
                sample=latent_model_input,
                timestep=timestep,
                step_index=step_index,
                total_step_count=total_step_count,
                conditioning_data=conditioning_data,
            )

                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
                # only apply controlnet if current step is within the controlnet's begin/end step range
                if step_index >= first_control_step and step_index <= last_control_step:
                    if cfg_injection:
                        control_latent_input = unet_latent_input
                    else:
                        # expand the latents input to control model if doing classifier free guidance
                        # (which I think for now is always true, there is conditional elsewhere that stops execution if
                        # classifier_free_guidance is <= 1.0 ?)
                        control_latent_input = torch.cat([unet_latent_input] * 2)

                    if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
                        encoder_hidden_states = conditioning_data.text_embeddings
                        encoder_attention_mask = None
                    else:
                        (
                            encoder_hidden_states,
                            encoder_attention_mask,
                        ) = self.invokeai_diffuser._concat_conditionings_for_batch(
                            conditioning_data.unconditioned_embeddings,
                            conditioning_data.text_embeddings,
                        )
                    if isinstance(control_datum.weight, list):
                        # if controlnet has multiple weights, use the weight for the current step
                        controlnet_weight = control_datum.weight[step_index]
                    else:
                        # if controlnet has a single weight, use it for all steps
                        controlnet_weight = control_datum.weight

                    # controlnet(s) inference
                    down_samples, mid_sample = control_datum.model(
                        sample=control_latent_input,
                        timestep=timestep,
                        encoder_hidden_states=encoder_hidden_states,
                        controlnet_cond=control_datum.image_tensor,
                        conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
                        encoder_attention_mask=encoder_attention_mask,
                        guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
                        return_dict=False,
                    )
                    if cfg_injection:
                        # Inferred ControlNet only for the conditional batch.
                        # To apply the output of ControlNet to both the unconditional and conditional batches,
                        # prepend zeros for unconditional batch
                        down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                        mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

                    if down_block_res_samples is None and mid_block_res_sample is None:
                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                    else:
                        # add controlnet outputs together if have multiple controlnets
                        down_block_res_samples = [
                            samples_prev + samples_curr
                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                        ]
                        mid_block_res_sample += mid_sample

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            x=unet_latent_input,
            sigma=t,
            unconditioning=conditioning_data.unconditioned_embeddings,
            conditioning=conditioning_data.text_embeddings,
            unconditional_guidance_scale=conditioning_data.guidance_scale,
        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
            timestep=t, # TODO: debug how handled batched and non batched timesteps
            step_index=step_index,
            total_step_count=total_step_count,
            down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
            mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
            conditioning_data=conditioning_data,

            # extra:
            down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
            mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
        )

        guidance_scale = conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[step_index]

        noise_pred = self.invokeai_diffuser._combine(
            uc_noise_pred,
            c_noise_pred,
            guidance_scale,
        )

        # compute the previous noisy sample x_t -> x_t-1
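After this hunk the per-step logic is split into three InvokeAIDiffuserComponent calls: do_controlnet_step produces the ControlNet residuals, do_unet_step returns separate unconditioned and conditioned noise predictions, and _combine applies classifier-free guidance, with an optional per-step guidance list indexed by step_index. A minimal sketch of that final combine step; the formula is the standard CFG combination, shown here with illustrative tensor shapes rather than InvokeAI's internals:

import torch


def combine_cfg(uc_noise_pred: torch.Tensor, c_noise_pred: torch.Tensor, guidance_scale) -> torch.Tensor:
    # Start from the unconditioned prediction and push it toward the conditioned
    # prediction, scaled by the CFG guidance value.
    return uc_noise_pred + guidance_scale * (c_noise_pred - uc_noise_pred)


uc_noise_pred = torch.zeros(1, 4, 64, 64)
c_noise_pred = torch.ones(1, 4, 64, 64)

guidance_scale = [7.5, 7.0, 6.5]  # a per-step schedule, as ConditioningData now allows
step_index = 1
scale = guidance_scale[step_index] if isinstance(guidance_scale, list) else guidance_scale

noise_pred = combine_cfg(uc_noise_pred, c_noise_pred, scale)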
@@ -738,41 +642,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            callback,
        )

    def img2img_from_latents_and_embeddings(
        self,
        initial_latents,
        num_inference_steps,
        conditioning_data: ConditioningData,
        strength,
        noise: torch.Tensor,
        run_id=None,
        callback=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
        result_latents, result_attention_maps = self.latents_from_embeddings(
            latents=initial_latents
            if strength < 1.0
            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
            num_inference_steps=num_inference_steps,
            conditioning_data=conditioning_data,
            timesteps=timesteps,
            noise=noise,
            run_id=run_id,
            callback=callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
        output = InvokeAIStableDiffusionPipelineOutput(
            images=image,
            nsfw_content_detected=[],
            attention_map_saver=result_attention_maps,
        )
        return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
        assert img2img_pipeline.scheduler is self.scheduler
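get_img2img_timesteps keeps wrapping a StableDiffusionImg2ImgPipeline built from the same components to derive the truncated schedule. The usual diffusers-style mapping from strength to the number of steps actually run looks roughly like the following standalone sketch (an approximation of that helper, not code copied from it):

from diffusers import DDIMScheduler


def sketch_img2img_timesteps(scheduler, num_inference_steps: int, strength: float):
    # strength=1.0 runs the full schedule; smaller values skip the earliest (noisiest) steps.
    scheduler.set_timesteps(num_inference_steps)
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return scheduler.timesteps[t_start * scheduler.order :], num_inference_steps - t_start


scheduler = DDIMScheduler(num_train_timesteps=1000)
timesteps, actual_steps = sketch_img2img_timesteps(scheduler, num_inference_steps=30, strength=0.6)
print(len(timesteps), actual_steps)  # 18 steps remain for strength=0.6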
@@ -877,7 +746,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            nsfw_content_detected=[],
            attention_map_saver=result_attention_maps,
        )
        return self.check_for_safety(output, dtype=conditioning_data.dtype)
        return self.check_for_safety(output, dtype=self.unet.dtype)

    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
        init_image = init_image.to(device=device, dtype=dtype)
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
import math
from typing import Any, Callable, Dict, Optional, Union, List

import numpy as np
@@ -127,33 +127,119 @@ class InvokeAIDiffuserComponent:
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)

    def do_diffusion_step(
    def do_controlnet_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        unconditioning: Union[torch.Tensor, dict],
        conditioning: Union[torch.Tensor, dict],
        # unconditional_guidance_scale: float,
        unconditional_guidance_scale: Union[float, List[float]],
        step_index: Optional[int] = None,
        total_step_count: Optional[int] = None,
        control_data,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        step_index: int,
        total_step_count: int,
        conditioning_data,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
        # and MultiControlNet (multiple ControlNetData in list)
        for i, control_datum in enumerate(control_data):
            control_mode = control_datum.control_mode
            # soft_injection and cfg_injection are the two ControlNet control_mode booleans
            # that are combined at higher level to make control_mode enum
            # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
            # or default weighting (if False)
            soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
            # cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
            # or the default both conditional and unconditional (if False)
            cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
            # only apply controlnet if current step is within the controlnet's begin/end step range
            if step_index >= first_control_step and step_index <= last_control_step:
                if cfg_injection:
                    sample_model_input = sample
                else:
                    # expand the latents input to control model if doing classifier free guidance
                    # (which I think for now is always true, there is conditional elsewhere that stops execution if
                    # classifier_free_guidance is <= 1.0 ?)
                    sample_model_input = torch.cat([sample] * 2)

                added_cond_kwargs = None

                if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
                    if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
                        added_cond_kwargs = {
                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
                        }
                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
                    encoder_attention_mask = None
                else:
                    if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
                        added_cond_kwargs = {
                            "text_embeds": torch.cat([
                                # TODO: how to pad? just by zeros? or even truncate?
                                conditioning_data.unconditioned_embeddings.pooled_embeds,
                                conditioning_data.text_embeddings.pooled_embeds,
                            ], dim=0),
                            "time_ids": torch.cat([
                                conditioning_data.unconditioned_embeddings.add_time_ids,
                                conditioning_data.text_embeddings.add_time_ids,
                            ], dim=0),
                        }
                    (
                        encoder_hidden_states,
                        encoder_attention_mask,
                    ) = self._concat_conditionings_for_batch(
                        conditioning_data.unconditioned_embeddings.embeds,
                        conditioning_data.text_embeddings.embeds
                    )
                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step
                    controlnet_weight = control_datum.weight[step_index]
                else:
                    # if controlnet has a single weight, use it for all steps
                    controlnet_weight = control_datum.weight

                # controlnet(s) inference
                down_samples, mid_sample = control_datum.model(
                    sample=sample_model_input,
                    timestep=timestep,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_datum.image_tensor,
                    conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
                    encoder_attention_mask=encoder_attention_mask,
                    guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
                    return_dict=False,
                )
                if cfg_injection:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # prepend zeros for unconditional batch
                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

                if down_block_res_samples is None and mid_block_res_sample is None:
                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                else:
                    # add controlnet outputs together if have multiple controlnets
                    down_block_res_samples = [
                        samples_prev + samples_curr
                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                    ]
                    mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample

    def do_unet_step(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        conditioning_data, # TODO: type
        step_index: int,
        total_step_count: int,
        **kwargs,
    ):
        """
        :param x: current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
        :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
        :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
        :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
        :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
        """

        if isinstance(unconditional_guidance_scale, list):
            guidance_scale = unconditional_guidance_scale[step_index]
        else:
            guidance_scale = unconditional_guidance_scale

        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
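do_controlnet_step keys its behaviour off two booleans derived from control_mode: soft_injection (per-layer re-weighting, passed to diffusers as guess_mode) and cfg_injection (run ControlNet on the conditional batch only, then zero-pad for the unconditional half). A small table-style helper illustrating that mapping, assuming the fourth mode is named "balanced" as elsewhere in InvokeAI:

from typing import Tuple


def control_mode_flags(control_mode: str) -> Tuple[bool, bool]:
    soft_injection = control_mode in ("more_prompt", "more_control")  # per-layer re-weighting ("guess mode")
    cfg_injection = control_mode in ("more_control", "unbalanced")    # ControlNet applied to conditional batch only
    return soft_injection, cfg_injection


for mode in ("balanced", "more_prompt", "more_control", "unbalanced"):
    print(mode, control_mode_flags(mode))
# balanced      -> (False, False)
# more_prompt   -> (True, False)
# more_control  -> (True, True)
# unbalanced    -> (False, True)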
@@ -163,25 +249,15 @@ class InvokeAIDiffuserComponent:
            )

        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
        wants_hybrid_conditioning = isinstance(conditioning, dict)

        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
                x,
                sigma,
                unconditioning,
                conditioning,
                **kwargs,
            )
        elif wants_cross_attention_control:
        if wants_cross_attention_control:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_cross_attention_controlled_conditioning(
                x,
                sigma,
                unconditioning,
                conditioning,
                sample,
                timestep,
                conditioning_data,
                cross_attention_control_types_to_do,
                **kwargs,
            )
@@ -190,10 +266,9 @@ class InvokeAIDiffuserComponent:
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                x,
                sigma,
                unconditioning,
                conditioning,
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )
@@ -202,21 +277,13 @@ class InvokeAIDiffuserComponent:
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                x,
                sigma,
                unconditioning,
                conditioning,
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )

        combined_next_x = self._combine(
            # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
            unconditioned_next_x,
            conditioned_next_x,
            guidance_scale,
        )

        return combined_next_x
        return unconditioned_next_x, conditioned_next_x

    def do_latent_postprocessing(
        self,
@@ -281,17 +348,35 @@ class InvokeAIDiffuserComponent:

    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
        added_cond_kwargs = None
        if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
            added_cond_kwargs = {
                "text_embeds": torch.cat([
                    # TODO: how to pad? just by zeros? or even truncate?
                    conditioning_data.unconditioned_embeddings.pooled_embeds,
                    conditioning_data.text_embeddings.pooled_embeds,
                ], dim=0),
                "time_ids": torch.cat([
                    conditioning_data.unconditioned_embeddings.add_time_ids,
                    conditioning_data.text_embeddings.add_time_ids,
                ], dim=0),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.unconditioned_embeddings.embeds,
            conditioning_data.text_embeddings.embeds
        )
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
            both_conditionings,
            encoder_attention_mask=encoder_attention_mask,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
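For SDXL the batched CFG path has to stack the pooled embeddings and time ids for the negative and positive prompts in the same unconditioned-first order as the main embeddings, and hand them to the UNet via added_cond_kwargs. A shape-level sketch; the dimensions shown are the usual SDXL sizes and are illustrative only:

import torch

# SDXL pooled text embedding: (batch, 1280); add_time_ids: (batch, 6) =
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
uncond_pooled = torch.zeros(1, 1280)
cond_pooled = torch.ones(1, 1280)
uncond_time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])
cond_time_ids = uncond_time_ids.clone()

added_cond_kwargs = {
    "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),   # (2, 1280), unconditioned first
    "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),  # (2, 6)
}
# added_cond_kwargs is then forwarded to the diffusers UNet alongside the
# concatenated prompt embeddings for the two-sample CFG batch.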
@@ -320,46 +405,41 @@ class InvokeAIDiffuserComponent:
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            unconditioning,
            conditioning_data.unconditioned_embeddings.embeds,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning,
            conditioning_data.text_embeddings.embeds,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x

    # TODO: looks unused
    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = dict()
        for k in conditioning:
            if isinstance(conditioning[k], list):
                both_conditionings[k] = [
                    torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
                ]
            else:
                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
            x_twice,
            sigma_twice,
            both_conditionings,
            **kwargs,
        ).chunk(2)
        return unconditioned_next_x, conditioned_next_x

    def _apply_cross_attention_controlled_conditioning(
        self,
        x: torch.Tensor,
@@ -391,26 +471,43 @@ class InvokeAIDiffuserComponent:
            mask=context.cross_attention_mask,
            cross_attention_types_to_do=[],
        )

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        # no cross attention for unconditioning (negative prompt)
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            unconditioning,
            conditioning_data.unconditioned_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        # do requested cross attention types for conditioning (positive prompt)
        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning,
            conditioning_data.text_embeddings.embeds,
            {"swap_cross_attn_context": cross_attn_processor_context},
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
@@ -564,7 +661,7 @@ class InvokeAIDiffuserComponent:
        # below is fugly omg
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = ceil(len(conditionings) / 2)
        chunk_count = math.ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2