mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
pass step count and step index to diffusion step func (#2342)
This commit is contained in:
parent
3a1724652e
commit
563196bd03
@ -391,7 +391,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t.fill_(t)
|
batched_t.fill_(t)
|
||||||
step_output = self.step(batched_t, latents, conditioning_data,
|
step_output = self.step(batched_t, latents, conditioning_data,
|
||||||
i, additional_guidance=additional_guidance)
|
step_index=i,
|
||||||
|
total_step_count=len(timesteps),
|
||||||
|
additional_guidance=additional_guidance)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||||
|
|
||||||
@ -410,7 +412,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
step_index:int | None = None, additional_guidance: List[Callable] = None):
|
step_index:int, total_step_count:int,
|
||||||
|
additional_guidance: List[Callable] = None):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
|
|
||||||
@ -427,6 +430,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
||||||
conditioning_data.guidance_scale,
|
conditioning_data.guidance_scale,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
|
total_step_count=total_step_count,
|
||||||
threshold=conditioning_data.threshold
|
threshold=conditioning_data.threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,6 +89,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioning: Union[torch.Tensor,dict],
|
conditioning: Union[torch.Tensor,dict],
|
||||||
unconditional_guidance_scale: float,
|
unconditional_guidance_scale: float,
|
||||||
step_index: Optional[int]=None,
|
step_index: Optional[int]=None,
|
||||||
|
total_step_count: Optional[int]=None,
|
||||||
threshold: Optional[ThresholdSettings]=None,
|
threshold: Optional[ThresholdSettings]=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -106,7 +107,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
if step_index is not None and total_step_count is not None:
|
||||||
|
# 🧨diffusers codepath
|
||||||
|
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
|
||||||
|
else:
|
||||||
|
# legacy compvis codepath
|
||||||
|
# TODO remove when compvis codepath support is dropped
|
||||||
|
if step_index is None and sigma is None:
|
||||||
|
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
|
||||||
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||||
|
|
||||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user