Mirror of https://github.com/invoke-ai/InvokeAI
pass step count and step index to diffusion step func (#2342)
parent 3a1724652e
commit 563196bd03
@@ -391,7 +391,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, conditioning_data,
-                                     i, additional_guidance=additional_guidance)
+                                     step_index=i,
+                                     total_step_count=len(timesteps),
+                                     additional_guidance=additional_guidance)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
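For context, a hedged sketch (not part of this commit) of where the timesteps collection in the loop above typically comes from in a diffusers-style pipeline; diffusers' DDIMScheduler is used here purely as a stand-in for whichever scheduler the pipeline is actually configured with:

    from diffusers import DDIMScheduler  # stand-in; any diffusers scheduler exposes the same interface

    scheduler = DDIMScheduler()      # default config, for illustration only
    scheduler.set_timesteps(50)      # request 50 denoising steps
    timesteps = scheduler.timesteps  # tensor of 50 timestep values, highest noise first

    print(len(timesteps))            # 50 -> the value passed as total_step_count above

Because len(timesteps) is the requested number of inference steps, step() now receives both the current iteration index and the length of the run, which is exactly what the cross-attention-control code further down needs to compute a progress fraction.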
@@ -410,7 +412,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
              conditioning_data: ConditioningData,
-             step_index:int | None = None, additional_guidance: List[Callable] = None):
+             step_index:int, total_step_count:int,
+             additional_guidance: List[Callable] = None):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
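A small self-contained sketch of the batching comment in the hunk above: InvokeAI's diffuser component carries the timestep as a batched tensor (one copy per latent in the batch), while diffusers schedulers expect a single scalar value, so step() unbatches it. The timestep value 981 is an arbitrary example:

    import torch

    # the same timestep repeated once per latent in a batch of two
    batched_t = torch.full((2,), 981, dtype=torch.long)

    # what step() hands on to the scheduler
    timestep = batched_t[0]
    print(timestep)  # tensor(981)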
@@ -427,6 +430,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
             conditioning_data.guidance_scale,
             step_index=step_index,
+            total_step_count=total_step_count,
             threshold=conditioning_data.threshold
         )
@@ -89,6 +89,7 @@ class InvokeAIDiffuserComponent:
                  conditioning: Union[torch.Tensor,dict],
                  unconditional_guidance_scale: float,
                  step_index: Optional[int]=None,
+                 total_step_count: Optional[int]=None,
                  threshold: Optional[ThresholdSettings]=None,
                  ):
         """
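Both new parameters on the InvokeAIDiffuserComponent side default to None, so call sites that predate this commit keep working; only the cross-attention-control branch below distinguishes the two cases. A minimal sketch of that pattern, using a hypothetical function in place of the component's real method (the method name is not shown in this hunk):

    from typing import Optional

    def guided_step(unconditional_guidance_scale: float,
                    step_index: Optional[int] = None,
                    total_step_count: Optional[int] = None) -> None:
        # hypothetical stand-in: the new keyword arguments are optional,
        # so legacy callers need no changes
        if step_index is not None and total_step_count is not None:
            print(f"diffusers codepath: step {step_index} of {total_step_count}")
        else:
            print("legacy compvis codepath: progress estimated from sigma")

    guided_step(7.5)                                     # old call style still works
    guided_step(7.5, step_index=3, total_step_count=50)  # new call style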
@@ -106,7 +107,15 @@ class InvokeAIDiffuserComponent:
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            percent_through = self.estimate_percent_through(step_index, sigma)
+            if step_index is not None and total_step_count is not None:
+                # 🧨diffusers codepath
+                percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
+            else:
+                # legacy compvis codepath
+                # TODO remove when compvis codepath support is dropped
+                if step_index is None and sigma is None:
+                    raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+                percent_through = self.estimate_percent_through(step_index, sigma)
             cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)

         wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
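The net effect: on the diffusers codepath the progress fraction is computed directly from the step index instead of being estimated from sigma. A self-contained sketch, assuming a 50-step run, showing the values this produces and why percent_through never reaches 1.0:

    total_step_count = 50

    percents = [step_index / total_step_count for step_index in range(total_step_count)]

    print(percents[0])   # 0.0  - the first step reports 0% progress
    print(percents[-1])  # 0.98 - the last step reports 0.98, never 1.0, which is deliberate

Cross-attention control schedules are expressed as fractions of the run, so keeping the value strictly below 1.0 presumably ensures that an effect scheduled to last until 100% stays active through the final step.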