Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function.
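
For context, the consolidated method keeps the standard denoising control flow: one initial progress callback, then one scheduler step per timestep, reporting intermediate latents after each step. A minimal sketch of that shape, with `step_fn` and `callback` as hypothetical stand-ins for `StableDiffusionGeneratorPipeline.step(...)` and the `PipelineIntermediateState` progress callback (not the real signatures):

import torch

def denoise_loop(
    latents: torch.Tensor,
    timesteps: torch.Tensor,
    step_fn,   # hypothetical stand-in for self.step(): noise prediction + scheduler step
    callback,  # hypothetical stand-in for the PipelineIntermediateState progress callback
) -> torch.Tensor:
    # Schematic of the consolidated loop; not the actual InvokeAI API.
    if timesteps.shape[0] == 0:
        return latents  # nothing to denoise; mirrors the early return in the diff below
    batch_size = latents.shape[0]
    callback(-1, latents)  # initial progress report before the first step
    for i, t in enumerate(timesteps):
        batched_t = t.expand(batch_size)       # broadcast the timestep over the batch
        latents = step_fn(batched_t, latents)  # one denoising step
        callback(i, latents)                   # report intermediate latents
    return latents

With this shape inlined, the try/finally in latents_from_embeddings wraps the loop directly, so self.invokeai_diffuser.model_forward_callback is restored even if a step raises.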

Ryan Dick 2024-06-13 12:24:43 -04:00 committed by Kent Keirsey
parent 80a67572f1
commit f604575862


@@ -325,17 +325,71 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
         try:
-            latents = self.generate_latents_from_embeddings(
-                latents,
-                timesteps,
-                conditioning_data,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                mask_guidance=mask_guidance,
-                control_data=control_data,
-                ip_adapter_data=ip_adapter_data,
-                t2i_adapter_data=t2i_adapter_data,
-                callback=callback,
-            )
+            self._adjust_memory_efficient_attention(latents)
+
+            batch_size = latents.shape[0]
+
+            if timesteps.shape[0] == 0:
+                return latents
+
+            use_ip_adapter = ip_adapter_data is not None
+            use_regional_prompting = (
+                conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+            )
+            unet_attention_patcher = None
+            attn_ctx = nullcontext()
+            if use_ip_adapter or use_regional_prompting:
+                ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                    [
+                        {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
+                        for ipa in ip_adapter_data
+                    ]
+                    if use_ip_adapter
+                    else None
+                )
+                unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+                attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+
+            with attn_ctx:
+                callback(
+                    PipelineIntermediateState(
+                        step=-1,
+                        order=self.scheduler.order,
+                        total_steps=len(timesteps),
+                        timestep=self.scheduler.config.num_train_timesteps,
+                        latents=latents,
+                    )
+                )
+
+                for i, t in enumerate(self.progress_bar(timesteps)):
+                    batched_t = t.expand(batch_size)
+                    step_output = self.step(
+                        batched_t,
+                        latents,
+                        conditioning_data,
+                        step_index=i,
+                        total_step_count=len(timesteps),
+                        scheduler_step_kwargs=scheduler_step_kwargs,
+                        mask_guidance=mask_guidance,
+                        control_data=control_data,
+                        ip_adapter_data=ip_adapter_data,
+                        t2i_adapter_data=t2i_adapter_data,
+                    )
+                    latents = step_output.prev_sample
+                    predicted_original = getattr(step_output, "pred_original_sample", None)
+
+                    callback(
+                        PipelineIntermediateState(
+                            step=i,
+                            order=self.scheduler.order,
+                            total_steps=len(timesteps),
+                            timestep=int(t),
+                            latents=latents,
+                            predicted_original=predicted_original,
+                        )
+                    )
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
@@ -351,82 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         return latents
 
-    def generate_latents_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        timesteps: torch.Tensor,
-        conditioning_data: TextConditioningData,
-        scheduler_step_kwargs: dict[str, Any],
-        callback: Callable[[PipelineIntermediateState], None],
-        mask_guidance: AddsMaskGuidance | None = None,
-        control_data: list[ControlNetData] | None = None,
-        ip_adapter_data: Optional[list[IPAdapterData]] = None,
-        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-    ) -> torch.Tensor:
-        self._adjust_memory_efficient_attention(latents)
-
-        batch_size = latents.shape[0]
-
-        if timesteps.shape[0] == 0:
-            return latents
-
-        use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = (
-            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
-        )
-        unet_attention_patcher = None
-        attn_ctx = nullcontext()
-        if use_ip_adapter or use_regional_prompting:
-            ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
-                if use_ip_adapter
-                else None
-            )
-            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
-            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-
-        with attn_ctx:
-            callback(
-                PipelineIntermediateState(
-                    step=-1,
-                    order=self.scheduler.order,
-                    total_steps=len(timesteps),
-                    timestep=self.scheduler.config.num_train_timesteps,
-                    latents=latents,
-                )
-            )
-
-            for i, t in enumerate(self.progress_bar(timesteps)):
-                batched_t = t.expand(batch_size)
-                step_output = self.step(
-                    batched_t,
-                    latents,
-                    conditioning_data,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    mask_guidance=mask_guidance,
-                    control_data=control_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                )
-                latents = step_output.prev_sample
-                predicted_original = getattr(step_output, "pred_original_sample", None)
-
-                callback(
-                    PipelineIntermediateState(
-                        step=i,
-                        order=self.scheduler.order,
-                        total_steps=len(timesteps),
-                        timestep=int(t),
-                        latents=latents,
-                        predicted_original=predicted_original,
-                    )
-                )
-
-        return latents
-
     @torch.inference_mode()
     def step(
         self,