mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix typing to reflect that the callback arg to latents_from_embeddings is never None.
This commit is contained in:
parent
d661517d94
commit
afaebdf151
@ -289,7 +289,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
seed: int,
|
seed: int,
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
init_timestep: torch.Tensor,
|
init_timestep: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None],
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
@ -363,11 +363,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data: TextConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
|
callback: Callable[[PipelineIntermediateState], None],
|
||||||
mask_guidance: AddsMaskGuidance | None = None,
|
mask_guidance: AddsMaskGuidance | None = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
|
|
||||||
@ -394,16 +394,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
|
|
||||||
with attn_ctx:
|
with attn_ctx:
|
||||||
if callback is not None:
|
callback(
|
||||||
callback(
|
PipelineIntermediateState(
|
||||||
PipelineIntermediateState(
|
step=-1,
|
||||||
step=-1,
|
order=self.scheduler.order,
|
||||||
order=self.scheduler.order,
|
total_steps=len(timesteps),
|
||||||
total_steps=len(timesteps),
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
latents=latents,
|
||||||
latents=latents,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t = t.expand(batch_size)
|
batched_t = t.expand(batch_size)
|
||||||
@ -422,17 +421,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||||
|
|
||||||
if callback is not None:
|
callback(
|
||||||
callback(
|
PipelineIntermediateState(
|
||||||
PipelineIntermediateState(
|
step=i,
|
||||||
step=i,
|
order=self.scheduler.order,
|
||||||
order=self.scheduler.order,
|
total_steps=len(timesteps),
|
||||||
total_steps=len(timesteps),
|
timestep=int(t),
|
||||||
timestep=int(t),
|
latents=latents,
|
||||||
latents=latents,
|
predicted_original=predicted_original,
|
||||||
predicted_original=predicted_original,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user