Fix error at high denoising_start, fix unipc(cpu_only)

This commit is contained in:
Sergey Borisov
2023-08-11 15:46:16 +03:00
parent 7479f9cc02
commit f3ae52ff97
2 changed files with 25 additions and 32 deletions

View File

@ -365,22 +365,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data: ConditioningData,
*,
noise: Optional[torch.Tensor],
timesteps=None,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
mask: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
# TODO:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device("cpu")
else:
scheduler_device = self.unet.device
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps = self.scheduler.timesteps
if init_timestep.shape[0] == 0:
return latents, None
infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState
@ -392,31 +386,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_t = torch.full(
(batch_size,),
timesteps[0],
dtype=timesteps.dtype,
device=self.unet.device,
)
batched_t = init_timestep.repeat(batch_size)
if noise is not None:
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
else:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if mask is not None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
if mask is not None:
if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
@ -428,6 +403,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._unet_forward, mask, orig_latents
)
else:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
try: