mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip
This commit is contained in:
parent
dc96a3e79d
commit
9aaf67c5b4
@ -37,6 +37,10 @@ class BasicConditioningInfo:
|
|||||||
# weight: float
|
# weight: float
|
||||||
# mode: ConditioningAlgo
|
# mode: ConditioningAlgo
|
||||||
|
|
||||||
|
def to(self, device, dtype=None):
|
||||||
|
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
@ -44,6 +48,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
|||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
add_time_ids: torch.Tensor
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
|
def to(self, device, dtype=None):
|
||||||
|
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||||
|
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||||
|
return super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
||||||
|
|
||||||
|
@ -174,11 +174,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
unet,
|
unet,
|
||||||
) -> ConditioningData:
|
) -> ConditioningData:
|
||||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
|
extra_conditioning_info = c.extra_conditioning
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
|
@ -212,8 +212,8 @@ class ControlNetData:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: torch.Tensor
|
unconditioned_embeddings: Any # TODO: type
|
||||||
text_embeddings: torch.Tensor
|
text_embeddings: Any # TODO: type
|
||||||
guidance_scale: Union[float, List[float]]
|
guidance_scale: Union[float, List[float]]
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
@ -392,48 +392,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
submodels.append(value)
|
submodels.append(value)
|
||||||
return submodels
|
return submodels
|
||||||
|
|
||||||
def image_from_embeddings(
|
|
||||||
self,
|
|
||||||
latents: torch.Tensor,
|
|
||||||
num_inference_steps: int,
|
|
||||||
conditioning_data: ConditioningData,
|
|
||||||
*,
|
|
||||||
noise: torch.Tensor,
|
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
|
||||||
run_id=None,
|
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
|
||||||
r"""
|
|
||||||
Function invoked when calling the pipeline for generation.
|
|
||||||
|
|
||||||
:param conditioning_data:
|
|
||||||
:param latents: Pre-generated un-noised latents, to be used as inputs for
|
|
||||||
image generation. Can be used to tweak the same generation with different prompts.
|
|
||||||
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
|
|
||||||
image at the expense of slower inference.
|
|
||||||
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
|
|
||||||
:param callback:
|
|
||||||
:param run_id:
|
|
||||||
"""
|
|
||||||
result_latents, result_attention_map_saver = self.latents_from_embeddings(
|
|
||||||
latents,
|
|
||||||
num_inference_steps,
|
|
||||||
conditioning_data,
|
|
||||||
noise=noise,
|
|
||||||
run_id=run_id,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
image = self.decode_latents(result_latents)
|
|
||||||
output = InvokeAIStableDiffusionPipelineOutput(
|
|
||||||
images=image,
|
|
||||||
nsfw_content_detected=[],
|
|
||||||
attention_map_saver=result_attention_map_saver,
|
|
||||||
)
|
|
||||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
|
||||||
|
|
||||||
def latents_from_embeddings(
|
def latents_from_embeddings(
|
||||||
self,
|
self,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
@ -492,13 +450,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
extra_conditioning_info=extra_conditioning_info,
|
extra_conditioning_info=extra_conditioning_info,
|
||||||
step_count=len(self.scheduler.timesteps),
|
step_count=len(self.scheduler.timesteps),
|
||||||
):
|
):
|
||||||
yield PipelineIntermediateState(
|
|
||||||
run_id=run_id,
|
|
||||||
step=-1,
|
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
|
||||||
latents=latents,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
batched_t = torch.full(
|
batched_t = torch.full(
|
||||||
(batch_size,),
|
(batch_size,),
|
||||||
@ -506,8 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
dtype=timesteps.dtype,
|
dtype=timesteps.dtype,
|
||||||
device=self._model_group.device_for(self.unet),
|
device=self._model_group.device_for(self.unet),
|
||||||
)
|
)
|
||||||
|
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
|
yield PipelineIntermediateState(
|
||||||
|
run_id=run_id,
|
||||||
|
step=-1,
|
||||||
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
|
latents=latents,
|
||||||
|
)
|
||||||
|
|
||||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||||
# print("timesteps:", timesteps)
|
# print("timesteps:", timesteps)
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
@ -569,95 +528,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
# default is no controlnet, so set controlnet processing output to None
|
# default is no controlnet, so set controlnet processing output to None
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
# control_data should be type List[ControlNetData]
|
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
control_data=control_data,
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
sample=latent_model_input,
|
||||||
for i, control_datum in enumerate(control_data):
|
|
||||||
control_mode = control_datum.control_mode
|
|
||||||
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
|
|
||||||
# that are combined at higher level to make control_mode enum
|
|
||||||
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
|
|
||||||
# or default weighting (if False)
|
|
||||||
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
|
|
||||||
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
|
|
||||||
# or the default both conditional and unconditional (if False)
|
|
||||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
|
||||||
|
|
||||||
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
|
||||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
|
||||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
|
||||||
if step_index >= first_control_step and step_index <= last_control_step:
|
|
||||||
if cfg_injection:
|
|
||||||
control_latent_input = unet_latent_input
|
|
||||||
else:
|
|
||||||
# expand the latents input to control model if doing classifier free guidance
|
|
||||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
|
||||||
# classifier_free_guidance is <= 1.0 ?)
|
|
||||||
control_latent_input = torch.cat([unet_latent_input] * 2)
|
|
||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings
|
|
||||||
encoder_attention_mask = None
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
) = self.invokeai_diffuser._concat_conditionings_for_batch(
|
|
||||||
conditioning_data.unconditioned_embeddings,
|
|
||||||
conditioning_data.text_embeddings,
|
|
||||||
)
|
|
||||||
if isinstance(control_datum.weight, list):
|
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
|
||||||
controlnet_weight = control_datum.weight[step_index]
|
|
||||||
else:
|
|
||||||
# if controlnet has a single weight, use it for all steps
|
|
||||||
controlnet_weight = control_datum.weight
|
|
||||||
|
|
||||||
# controlnet(s) inference
|
|
||||||
down_samples, mid_sample = control_datum.model(
|
|
||||||
sample=control_latent_input,
|
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
controlnet_cond=control_datum.image_tensor,
|
|
||||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
|
||||||
return_dict=False,
|
|
||||||
)
|
|
||||||
if cfg_injection:
|
|
||||||
# Inferred ControlNet only for the conditional batch.
|
|
||||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
|
||||||
# prepend zeros for unconditional batch
|
|
||||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
|
||||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
|
||||||
|
|
||||||
if down_block_res_samples is None and mid_block_res_sample is None:
|
|
||||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
|
||||||
else:
|
|
||||||
# add controlnet outputs together if have multiple controlnets
|
|
||||||
down_block_res_samples = [
|
|
||||||
samples_prev + samples_curr
|
|
||||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
|
||||||
]
|
|
||||||
mid_block_res_sample += mid_sample
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
|
||||||
x=unet_latent_input,
|
|
||||||
sigma=t,
|
|
||||||
unconditioning=conditioning_data.unconditioned_embeddings,
|
|
||||||
conditioning=conditioning_data.text_embeddings,
|
|
||||||
unconditional_guidance_scale=conditioning_data.guidance_scale,
|
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
conditioning_data=conditioning_data,
|
||||||
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
)
|
||||||
|
|
||||||
|
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||||
|
sample=latent_model_input,
|
||||||
|
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||||
|
step_index=step_index,
|
||||||
|
total_step_count=total_step_count,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
|
||||||
|
# extra:
|
||||||
|
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
|
||||||
|
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
|
||||||
|
)
|
||||||
|
|
||||||
|
guidance_scale = conditioning_data.guidance_scale
|
||||||
|
if isinstance(guidance_scale, list):
|
||||||
|
guidance_scale = guidance_scale[step_index]
|
||||||
|
|
||||||
|
noise_pred = self.invokeai_diffuser._combine(
|
||||||
|
uc_noise_pred,
|
||||||
|
c_noise_pred,
|
||||||
|
guidance_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
@ -738,41 +642,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
def img2img_from_latents_and_embeddings(
|
|
||||||
self,
|
|
||||||
initial_latents,
|
|
||||||
num_inference_steps,
|
|
||||||
conditioning_data: ConditioningData,
|
|
||||||
strength,
|
|
||||||
noise: torch.Tensor,
|
|
||||||
run_id=None,
|
|
||||||
callback=None,
|
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
|
||||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
|
||||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
|
||||||
latents=initial_latents
|
|
||||||
if strength < 1.0
|
|
||||||
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
|
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
timesteps=timesteps,
|
|
||||||
noise=noise,
|
|
||||||
run_id=run_id,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
image = self.decode_latents(result_latents)
|
|
||||||
output = InvokeAIStableDiffusionPipelineOutput(
|
|
||||||
images=image,
|
|
||||||
nsfw_content_detected=[],
|
|
||||||
attention_map_saver=result_attention_maps,
|
|
||||||
)
|
|
||||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
|
||||||
|
|
||||||
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
|
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
|
||||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
assert img2img_pipeline.scheduler is self.scheduler
|
assert img2img_pipeline.scheduler is self.scheduler
|
||||||
@ -877,7 +746,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
nsfw_content_detected=[],
|
nsfw_content_detected=[],
|
||||||
attention_map_saver=result_attention_maps,
|
attention_map_saver=result_attention_maps,
|
||||||
)
|
)
|
||||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
return self.check_for_safety(output, dtype=self.unet.dtype)
|
||||||
|
|
||||||
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
|
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
|
||||||
init_image = init_image.to(device=device, dtype=dtype)
|
init_image = init_image.to(device=device, dtype=dtype)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
import math
|
||||||
from typing import Any, Callable, Dict, Optional, Union, List
|
from typing import Any, Callable, Dict, Optional, Union, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -127,33 +127,119 @@ class InvokeAIDiffuserComponent:
|
|||||||
for _, module in tokens_cross_attention_modules:
|
for _, module in tokens_cross_attention_modules:
|
||||||
module.set_attention_slice_calculated_callback(None)
|
module.set_attention_slice_calculated_callback(None)
|
||||||
|
|
||||||
def do_diffusion_step(
|
def do_controlnet_step(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
control_data,
|
||||||
sigma: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
unconditioning: Union[torch.Tensor, dict],
|
timestep: torch.Tensor,
|
||||||
conditioning: Union[torch.Tensor, dict],
|
step_index: int,
|
||||||
# unconditional_guidance_scale: float,
|
total_step_count: int,
|
||||||
unconditional_guidance_scale: Union[float, List[float]],
|
conditioning_data,
|
||||||
step_index: Optional[int] = None,
|
):
|
||||||
total_step_count: Optional[int] = None,
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
|
# control_data should be type List[ControlNetData]
|
||||||
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
|
for i, control_datum in enumerate(control_data):
|
||||||
|
control_mode = control_datum.control_mode
|
||||||
|
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
|
||||||
|
# that are combined at higher level to make control_mode enum
|
||||||
|
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
|
||||||
|
# or default weighting (if False)
|
||||||
|
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
|
||||||
|
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
|
||||||
|
# or the default both conditional and unconditional (if False)
|
||||||
|
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||||
|
|
||||||
|
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
||||||
|
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||||
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
|
if cfg_injection:
|
||||||
|
sample_model_input = sample
|
||||||
|
else:
|
||||||
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
|
# classifier_free_guidance is <= 1.0 ?)
|
||||||
|
sample_model_input = torch.cat([sample] * 2)
|
||||||
|
|
||||||
|
added_cond_kwargs = None
|
||||||
|
|
||||||
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
|
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||||
|
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||||
|
}
|
||||||
|
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||||
|
encoder_attention_mask = None
|
||||||
|
else:
|
||||||
|
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": torch.cat([
|
||||||
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
|
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
|
conditioning_data.text_embeddings.pooled_embeds,
|
||||||
|
], dim=0),
|
||||||
|
"time_ids": torch.cat([
|
||||||
|
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
|
conditioning_data.text_embeddings.add_time_ids,
|
||||||
|
], dim=0),
|
||||||
|
}
|
||||||
|
(
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = self._concat_conditionings_for_batch(
|
||||||
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
|
conditioning_data.text_embeddings.embeds,
|
||||||
|
)
|
||||||
|
if isinstance(control_datum.weight, list):
|
||||||
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
|
controlnet_weight = control_datum.weight[step_index]
|
||||||
|
else:
|
||||||
|
# if controlnet has a single weight, use it for all steps
|
||||||
|
controlnet_weight = control_datum.weight
|
||||||
|
|
||||||
|
# controlnet(s) inference
|
||||||
|
down_samples, mid_sample = control_datum.model(
|
||||||
|
sample=sample_model_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
controlnet_cond=control_datum.image_tensor,
|
||||||
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
if cfg_injection:
|
||||||
|
# Inferred ControlNet only for the conditional batch.
|
||||||
|
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||||
|
# prepend zeros for unconditional batch
|
||||||
|
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||||
|
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||||
|
|
||||||
|
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||||
|
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||||
|
else:
|
||||||
|
# add controlnet outputs together if have multiple controlnets
|
||||||
|
down_block_res_samples = [
|
||||||
|
samples_prev + samples_curr
|
||||||
|
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||||
|
]
|
||||||
|
mid_block_res_sample += mid_sample
|
||||||
|
|
||||||
|
return down_block_res_samples, mid_block_res_sample
|
||||||
|
|
||||||
|
def do_unet_step(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
conditioning_data, # TODO: type
|
||||||
|
step_index: int,
|
||||||
|
total_step_count: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
:param x: current latents
|
|
||||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
|
||||||
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
|
||||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
|
||||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
|
||||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
|
||||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(unconditional_guidance_scale, list):
|
|
||||||
guidance_scale = unconditional_guidance_scale[step_index]
|
|
||||||
else:
|
|
||||||
guidance_scale = unconditional_guidance_scale
|
|
||||||
|
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
@ -163,25 +249,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
|
||||||
|
|
||||||
if wants_hybrid_conditioning:
|
if wants_cross_attention_control:
|
||||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif wants_cross_attention_control:
|
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_cross_attention_controlled_conditioning(
|
) = self._apply_cross_attention_controlled_conditioning(
|
||||||
x,
|
sample,
|
||||||
sigma,
|
timestep,
|
||||||
unconditioning,
|
conditioning_data,
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -190,10 +266,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning_sequentially(
|
) = self._apply_standard_conditioning_sequentially(
|
||||||
x,
|
sample,
|
||||||
sigma,
|
timestep,
|
||||||
unconditioning,
|
conditioning_data,
|
||||||
conditioning,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -202,21 +277,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning(
|
) = self._apply_standard_conditioning(
|
||||||
x,
|
sample,
|
||||||
sigma,
|
timestep,
|
||||||
unconditioning,
|
conditioning_data,
|
||||||
conditioning,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
combined_next_x = self._combine(
|
return unconditioned_next_x, conditioned_next_x
|
||||||
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
|
||||||
unconditioned_next_x,
|
|
||||||
conditioned_next_x,
|
|
||||||
guidance_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
return combined_next_x
|
|
||||||
|
|
||||||
def do_latent_postprocessing(
|
def do_latent_postprocessing(
|
||||||
self,
|
self,
|
||||||
@ -281,17 +348,35 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
|
added_cond_kwargs = None
|
||||||
|
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": torch.cat([
|
||||||
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
|
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
|
conditioning_data.text_embeddings.pooled_embeds,
|
||||||
|
], dim=0),
|
||||||
|
"time_ids": torch.cat([
|
||||||
|
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
|
conditioning_data.text_embeddings.add_time_ids,
|
||||||
|
], dim=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
|
conditioning_data.text_embeddings.embeds
|
||||||
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
sigma_twice,
|
sigma_twice,
|
||||||
both_conditionings,
|
both_conditionings,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
@ -320,46 +405,41 @@ class InvokeAIDiffuserComponent:
|
|||||||
if mid_block_additional_residual is not None:
|
if mid_block_additional_residual is not None:
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
|
added_cond_kwargs = None
|
||||||
|
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
||||||
|
if is_sdxl:
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
|
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
|
}
|
||||||
|
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
down_block_additional_residuals=uncond_down_block,
|
down_block_additional_residuals=uncond_down_block,
|
||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||||
|
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||||
|
}
|
||||||
|
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning,
|
conditioning_data.text_embeddings.embeds,
|
||||||
down_block_additional_residuals=cond_down_block,
|
down_block_additional_residuals=cond_down_block,
|
||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
# TODO: looks unused
|
|
||||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
|
||||||
assert isinstance(conditioning, dict)
|
|
||||||
assert isinstance(unconditioning, dict)
|
|
||||||
x_twice = torch.cat([x] * 2)
|
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
|
||||||
both_conditionings = dict()
|
|
||||||
for k in conditioning:
|
|
||||||
if isinstance(conditioning[k], list):
|
|
||||||
both_conditionings[k] = [
|
|
||||||
torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
|
||||||
x_twice,
|
|
||||||
sigma_twice,
|
|
||||||
both_conditionings,
|
|
||||||
**kwargs,
|
|
||||||
).chunk(2)
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning(
|
def _apply_cross_attention_controlled_conditioning(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -391,26 +471,43 @@ class InvokeAIDiffuserComponent:
|
|||||||
mask=context.cross_attention_mask,
|
mask=context.cross_attention_mask,
|
||||||
cross_attention_types_to_do=[],
|
cross_attention_types_to_do=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
added_cond_kwargs = None
|
||||||
|
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
||||||
|
if is_sdxl:
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
|
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
|
}
|
||||||
|
|
||||||
# no cross attention for unconditioning (negative prompt)
|
# no cross attention for unconditioning (negative prompt)
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
down_block_additional_residuals=uncond_down_block,
|
down_block_additional_residuals=uncond_down_block,
|
||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||||
|
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||||
|
}
|
||||||
|
|
||||||
# do requested cross attention types for conditioning (positive prompt)
|
# do requested cross attention types for conditioning (positive prompt)
|
||||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning,
|
conditioning_data.text_embeddings.embeds,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
down_block_additional_residuals=cond_down_block,
|
down_block_additional_residuals=cond_down_block,
|
||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
@ -564,7 +661,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# below is fugly omg
|
# below is fugly omg
|
||||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||||
chunk_count = ceil(len(conditionings) / 2)
|
chunk_count = math.ceil(len(conditionings) / 2)
|
||||||
deltas = None
|
deltas = None
|
||||||
for chunk_index in range(chunk_count):
|
for chunk_index in range(chunk_count):
|
||||||
offset = chunk_index * 2
|
offset = chunk_index * 2
|
||||||
|
Loading…
Reference in New Issue
Block a user