Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function.
This commit is contained in:
Parent: 80a67572f1
Commit: f604575862
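The refactor collapses a thin public wrapper and the private worker it delegated to into one function, so a single signature owns the denoising loop and its cleanup. A minimal sketch of the before/after shape, with simplified stand-in names rather than the real pipeline API:

```python
# Illustrative sketch only; class and helper names here are invented
# stand-ins, not the actual InvokeAI pipeline code.

class Before:
    def latents_from_embeddings(self, latents):
        try:
            return self._generate(latents)  # delegate to a second function
        finally:
            self._cleanup()  # e.g. restore a patched forward callback

    def _generate(self, latents):
        return latents  # stand-in for the denoising loop

    def _cleanup(self):
        pass


class After:
    def latents_from_embeddings(self, latents):
        try:
            return latents  # the denoising loop now runs inline here
        finally:
            self._cleanup()

    def _cleanup(self):
        pass
```

With a single function there is no second signature to keep in sync each time a new argument (control data, IP-Adapter data, callbacks) is threaded through the pipeline.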
```diff
@@ -325,44 +325,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
-        try:
-            latents = self.generate_latents_from_embeddings(
-                latents,
-                timesteps,
-                conditioning_data,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                mask_guidance=mask_guidance,
-                control_data=control_data,
-                ip_adapter_data=ip_adapter_data,
-                t2i_adapter_data=t2i_adapter_data,
-                callback=callback,
-            )
-        finally:
-            self.invokeai_diffuser.model_forward_callback = self._unet_forward
-
-        # restore unmasked part after the last step is completed
-        # in-process masking happens before each step
-        if mask is not None:
-            if is_gradient_mask:
-                latents = torch.where(mask > 0, latents, orig_latents)
-            else:
-                latents = torch.lerp(
-                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
-                )
-
-        return latents
-
-    def generate_latents_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        timesteps: torch.Tensor,
-        conditioning_data: TextConditioningData,
-        scheduler_step_kwargs: dict[str, Any],
-        callback: Callable[[PipelineIntermediateState], None],
-        mask_guidance: AddsMaskGuidance | None = None,
-        control_data: list[ControlNetData] | None = None,
-        ip_adapter_data: Optional[list[IPAdapterData]] = None,
-        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-    ) -> torch.Tensor:
         self._adjust_memory_efficient_attention(latents)
 
         batch_size = latents.shape[0]
```
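The deleted block above carries the inpainting mask restoration that this commit relocates into the consolidated function (see the last hunk below). As a standalone illustration of the two paths, here is a rough sketch with dummy tensors; the shapes and the mask convention (values near 1 keep the new latents, values near 0 restore the originals) are assumptions for the demo, inferred from the code rather than stated by the commit:

```python
import torch

orig_latents = torch.zeros(1, 4, 8, 8)  # latents before denoising
latents = torch.ones(1, 4, 8, 8)        # latents after denoising
mask = torch.rand(1, 1, 8, 8).expand(-1, 4, -1, -1)

# is_gradient_mask=True path: hard per-element selection; any touched
# element (mask > 0) keeps the newly generated value.
restored_hard = torch.where(mask > 0, latents, orig_latents)

# is_gradient_mask=False path: soft linear blend,
# orig * (1 - mask) + latents * mask.
restored_soft = torch.lerp(
    orig_latents, latents.to(orig_latents.dtype), mask.to(orig_latents.dtype)
)
```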
```diff
@@ -379,7 +341,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         if use_ip_adapter or use_regional_prompting:
             ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
+                [
+                    {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
+                    for ipa in ip_adapter_data
+                ]
                 if use_ip_adapter
                 else None
             )
```
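The reformatting above leaves the logic unchanged: a list of adapter/target-block pairs is built only when IP-Adapters are active, and None is passed otherwise. A self-contained sketch of the same conditional-comprehension pattern; the dataclass and helper function below are stand-ins for the demo, and only the two field names come from the diff:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class IPAdapterData:  # stand-in for the pipeline's real class
    ip_adapter_model: str
    target_blocks: List[str]


def build_unet_ip_adapters(  # hypothetical helper, not in the codebase
    ip_adapter_data: List[IPAdapterData], use_ip_adapter: bool
) -> Optional[List[dict]]:
    return (
        [
            {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
            for ipa in ip_adapter_data
        ]
        if use_ip_adapter
        else None
    )
```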
```diff
@@ -425,6 +390,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
             )
 
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward
+
+        # restore unmasked part after the last step is completed
+        # in-process masking happens before each step
+        if mask is not None:
+            if is_gradient_mask:
+                latents = torch.where(mask > 0, latents, orig_latents)
+            else:
+                latents = torch.lerp(
+                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
+                )
+
         return latents
 
     @torch.inference_mode()
```
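The finally block added here guarantees that model_forward_callback is reset to the pipeline's own _unet_forward even if the denoising loop raises. A generic sketch of that patch-and-restore pattern; the Diffuser class below is a stand-in, not the real invokeai_diffuser object:

```python
from typing import Callable


class Diffuser:  # stand-in for invokeai_diffuser
    def __init__(self, forward: Callable[[float], float]) -> None:
        self.model_forward_callback = forward


def unet_forward(x: float) -> float:  # stand-in for self._unet_forward
    return x


def run_with_patched_forward(diffuser: Diffuser, patched: Callable[[float], float]) -> float:
    diffuser.model_forward_callback = patched
    try:
        return diffuser.model_forward_callback(1.0)  # loop would run here
    finally:
        # Runs on success or on exception, so the pipeline is never left
        # with a stale patched callback, mirroring the diff above.
        diffuser.model_forward_callback = unet_forward
```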