mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix error at high denoising_start, fix unipc(cpu_only)
This commit is contained in:
parent
7479f9cc02
commit
f3ae52ff97
@ -317,6 +317,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return control_data
|
return control_data
|
||||||
|
|
||||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||||
|
if scheduler.config.get("cpu_only", False):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
# apply denoising_start
|
# apply denoising_start
|
||||||
num_inference_steps = steps
|
num_inference_steps = steps
|
||||||
scheduler.set_timesteps(num_inference_steps, device=device)
|
scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
@ -325,6 +328,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||||
num_inference_steps = num_inference_steps - t_start
|
num_inference_steps = num_inference_steps - t_start
|
||||||
|
|
||||||
|
init_timestep = timesteps[:1]
|
||||||
|
|
||||||
# apply denoising_end
|
# apply denoising_end
|
||||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||||
|
|
||||||
@ -332,7 +337,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||||
|
|
||||||
return num_inference_steps, timesteps
|
return num_inference_steps, timesteps, init_timestep
|
||||||
|
|
||||||
def prep_mask_tensor(self, mask, context, lantents):
|
def prep_mask_tensor(self, mask, context, lantents):
|
||||||
if mask is None:
|
if mask is None:
|
||||||
@ -418,7 +423,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_inference_steps, timesteps = self.init_scheduler(
|
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||||
scheduler,
|
scheduler,
|
||||||
device=unet.device,
|
device=unet.device,
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
@ -429,6 +434,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
|
@ -365,22 +365,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
*,
|
*,
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
timesteps=None,
|
timesteps: torch.Tensor,
|
||||||
|
init_timestep: torch.Tensor,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
# TODO:
|
if init_timestep.shape[0] == 0:
|
||||||
if self.scheduler.config.get("cpu_only", False):
|
return latents, None
|
||||||
scheduler_device = torch.device("cpu")
|
|
||||||
else:
|
|
||||||
scheduler_device = self.unet.device
|
|
||||||
|
|
||||||
if timesteps is None:
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
|
||||||
timesteps = self.scheduler.timesteps
|
|
||||||
|
|
||||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||||
@ -392,31 +386,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
orig_latents = latents.clone()
|
orig_latents = latents.clone()
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
batched_t = torch.full(
|
batched_t = init_timestep.repeat(batch_size)
|
||||||
(batch_size,),
|
|
||||||
timesteps[0],
|
|
||||||
dtype=timesteps.dtype,
|
|
||||||
device=self.unet.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
else:
|
|
||||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
|
||||||
if mask is not None:
|
|
||||||
noise = torch.randn(
|
|
||||||
orig_latents.shape,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
|
||||||
|
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
|
||||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
|
||||||
|
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if is_inpainting_model(self.unet):
|
if is_inpainting_model(self.unet):
|
||||||
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
|
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
|
||||||
@ -428,6 +403,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self._unet_forward, mask, orig_latents
|
self._unet_forward, mask, orig_latents
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn(
|
||||||
|
orig_latents.shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
||||||
|
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||||
|
|
||||||
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||||
|
|
||||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user