Pass step count and step index to the diffusion step function (#2342)

This commit is contained in:
Damian Stewart 2023-01-16 20:56:54 +01:00 committed by GitHub
parent 3a1724652e
commit 563196bd03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 3 deletions

View File

@ -391,7 +391,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
i, additional_guidance=additional_guidance)
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)
@ -410,7 +412,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor,
conditioning_data: ConditioningData,
step_index:int | None = None, additional_guidance: List[Callable] = None):
step_index:int, total_step_count:int,
additional_guidance: List[Callable] = None):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@ -427,6 +430,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
threshold=conditioning_data.threshold
)

View File

@ -89,6 +89,7 @@ class InvokeAIDiffuserComponent:
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: Optional[int]=None,
total_step_count: Optional[int]=None,
threshold: Optional[ThresholdSettings]=None,
):
"""
@ -106,7 +107,15 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)