Fix typing to reflect that the callback arg to latents_from_embeddings is never None.

This commit is contained in:
Ryan Dick 2024-06-12 14:41:01 -04:00 committed by Kent Keirsey
parent d661517d94
commit afaebdf151

View File

@ -289,7 +289,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
seed: int, seed: int,
timesteps: torch.Tensor, timesteps: torch.Tensor,
init_timestep: torch.Tensor, init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None],
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
@ -363,11 +363,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps, timesteps,
conditioning_data: TextConditioningData, conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any], scheduler_step_kwargs: dict[str, Any],
callback: Callable[[PipelineIntermediateState], None],
mask_guidance: AddsMaskGuidance | None = None, mask_guidance: AddsMaskGuidance | None = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor: ) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
@ -394,16 +394,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx: with attn_ctx:
if callback is not None: callback(
callback( PipelineIntermediateState(
PipelineIntermediateState( step=-1,
step=-1, order=self.scheduler.order,
order=self.scheduler.order, total_steps=len(timesteps),
total_steps=len(timesteps), timestep=self.scheduler.config.num_train_timesteps,
timestep=self.scheduler.config.num_train_timesteps, latents=latents,
latents=latents,
)
) )
)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size) batched_t = t.expand(batch_size)
@ -422,17 +421,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = step_output.prev_sample latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None) predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None: callback(
callback( PipelineIntermediateState(
PipelineIntermediateState( step=i,
step=i, order=self.scheduler.order,
order=self.scheduler.order, total_steps=len(timesteps),
total_steps=len(timesteps), timestep=int(t),
timestep=int(t), latents=latents,
latents=latents, predicted_original=predicted_original,
predicted_original=predicted_original,
)
) )
)
return latents return latents