Pass step count and step index to the diffusion step function (#2342)

This commit is contained in:
Damian Stewart 2023-01-16 20:56:54 +01:00 committed by GitHub
parent 3a1724652e
commit 563196bd03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 3 deletions

View File

@ -391,7 +391,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data, step_output = self.step(batched_t, latents, conditioning_data,
i, additional_guidance=additional_guidance) step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None) predicted_original = getattr(step_output, 'pred_original_sample', None)
@ -410,7 +412,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
@torch.inference_mode() @torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor, def step(self, t: torch.Tensor, latents: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
step_index:int | None = None, additional_guidance: List[Callable] = None): step_index:int, total_step_count:int,
additional_guidance: List[Callable] = None):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -427,6 +430,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings, conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
conditioning_data.guidance_scale, conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count,
threshold=conditioning_data.threshold threshold=conditioning_data.threshold
) )

View File

@ -89,6 +89,7 @@ class InvokeAIDiffuserComponent:
conditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float, unconditional_guidance_scale: float,
step_index: Optional[int]=None, step_index: Optional[int]=None,
total_step_count: Optional[int]=None,
threshold: Optional[ThresholdSettings]=None, threshold: Optional[ThresholdSettings]=None,
): ):
""" """
@ -106,7 +107,15 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma) if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)