mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function.
This commit is contained in:
parent
80a67572f1
commit
f604575862
@ -325,17 +325,71 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
|
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
latents = self.generate_latents_from_embeddings(
|
self._adjust_memory_efficient_attention(latents)
|
||||||
latents,
|
|
||||||
timesteps,
|
batch_size = latents.shape[0]
|
||||||
conditioning_data,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
if timesteps.shape[0] == 0:
|
||||||
mask_guidance=mask_guidance,
|
return latents
|
||||||
control_data=control_data,
|
|
||||||
ip_adapter_data=ip_adapter_data,
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
use_regional_prompting = (
|
||||||
callback=callback,
|
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||||
)
|
)
|
||||||
|
unet_attention_patcher = None
|
||||||
|
attn_ctx = nullcontext()
|
||||||
|
|
||||||
|
if use_ip_adapter or use_regional_prompting:
|
||||||
|
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||||
|
[
|
||||||
|
{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
|
||||||
|
for ipa in ip_adapter_data
|
||||||
|
]
|
||||||
|
if use_ip_adapter
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||||
|
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
|
|
||||||
|
with attn_ctx:
|
||||||
|
callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=-1,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
|
latents=latents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
batched_t = t.expand(batch_size)
|
||||||
|
step_output = self.step(
|
||||||
|
batched_t,
|
||||||
|
latents,
|
||||||
|
conditioning_data,
|
||||||
|
step_index=i,
|
||||||
|
total_step_count=len(timesteps),
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
mask_guidance=mask_guidance,
|
||||||
|
control_data=control_data,
|
||||||
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
|
)
|
||||||
|
latents = step_output.prev_sample
|
||||||
|
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||||
|
|
||||||
|
callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=i,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
|
timestep=int(t),
|
||||||
|
latents=latents,
|
||||||
|
predicted_original=predicted_original,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||||
|
|
||||||
@ -351,82 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
def generate_latents_from_embeddings(
|
|
||||||
self,
|
|
||||||
latents: torch.Tensor,
|
|
||||||
timesteps: torch.Tensor,
|
|
||||||
conditioning_data: TextConditioningData,
|
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
|
||||||
callback: Callable[[PipelineIntermediateState], None],
|
|
||||||
mask_guidance: AddsMaskGuidance | None = None,
|
|
||||||
control_data: list[ControlNetData] | None = None,
|
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
self._adjust_memory_efficient_attention(latents)
|
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
|
||||||
|
|
||||||
if timesteps.shape[0] == 0:
|
|
||||||
return latents
|
|
||||||
|
|
||||||
use_ip_adapter = ip_adapter_data is not None
|
|
||||||
use_regional_prompting = (
|
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
|
||||||
)
|
|
||||||
unet_attention_patcher = None
|
|
||||||
attn_ctx = nullcontext()
|
|
||||||
|
|
||||||
if use_ip_adapter or use_regional_prompting:
|
|
||||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
|
||||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
|
||||||
if use_ip_adapter
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
|
||||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
|
||||||
|
|
||||||
with attn_ctx:
|
|
||||||
callback(
|
|
||||||
PipelineIntermediateState(
|
|
||||||
step=-1,
|
|
||||||
order=self.scheduler.order,
|
|
||||||
total_steps=len(timesteps),
|
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
|
||||||
latents=latents,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
|
||||||
batched_t = t.expand(batch_size)
|
|
||||||
step_output = self.step(
|
|
||||||
batched_t,
|
|
||||||
latents,
|
|
||||||
conditioning_data,
|
|
||||||
step_index=i,
|
|
||||||
total_step_count=len(timesteps),
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
mask_guidance=mask_guidance,
|
|
||||||
control_data=control_data,
|
|
||||||
ip_adapter_data=ip_adapter_data,
|
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
|
||||||
)
|
|
||||||
latents = step_output.prev_sample
|
|
||||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
|
||||||
|
|
||||||
callback(
|
|
||||||
PipelineIntermediateState(
|
|
||||||
step=i,
|
|
||||||
order=self.scheduler.order,
|
|
||||||
total_steps=len(timesteps),
|
|
||||||
timestep=int(t),
|
|
||||||
latents=latents,
|
|
||||||
predicted_original=predicted_original,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return latents
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user