Fix error on zero timesteps

This commit is contained in:
Sergey Borisov 2023-08-14 00:20:01 +03:00
parent 7a8f14d595
commit 096333ba3f

View File

@ -385,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
orig_latents = latents.clone() orig_latents = latents.clone()
batch_size = latents.shape[0] batch_size = latents.shape[0]
batched_t = init_timestep.repeat(batch_size) batched_t = init_timestep.expand(batch_size)
if noise is not None: if noise is not None:
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
@ -448,20 +448,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
batch_size = latents.shape[0]
attention_map_saver: Optional[AttentionMapSaver] = None
if timesteps.shape[0] == 0:
return latents, attention_map_saver
extra_conditioning_info = conditioning_data.extra extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context( with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info, extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps), step_count=len(self.scheduler.timesteps),
): ):
batch_size = latents.shape[0]
batched_t = torch.full(
(batch_size,),
timesteps[0],
dtype=timesteps.dtype,
device=self.unet.device,
)
yield PipelineIntermediateState( yield PipelineIntermediateState(
step=-1, step=-1,
order=self.scheduler.order, order=self.scheduler.order,
@ -470,10 +469,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents=latents, latents=latents,
) )
attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps) # print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t = t.expand(batch_size)
step_output = self.step( step_output = self.step(
batched_t, batched_t,
latents, latents,