Remove scheduler_args from ConditioningData structure.

This commit is contained in:
Ryan Dick 2024-02-28 12:15:39 -05:00
parent bf3ee1fefa
commit ef9e0c969b
3 changed files with 23 additions and 36 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math import math
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
@ -368,9 +369,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data( def get_conditioning_data(
self, self,
context: InvocationContext, context: InvocationContext,
scheduler: Scheduler,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
seed: int,
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -385,14 +384,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
) )
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
# flip all bits to have noise different from initial
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
)
return conditioning_data return conditioning_data
def create_pipeline( def create_pipeline(
@ -636,6 +627,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
steps: int, steps: int,
denoising_start: float, denoising_start: float,
denoising_end: float, denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int]: ) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin) assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
@ -664,7 +656,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order num_inference_steps = len(timesteps) // scheduler.order
return num_inference_steps, timesteps, init_timestep scheduler_step_kwargs = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask( def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor self, context: InvocationContext, latents: torch.Tensor
@ -758,7 +758,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) conditioning_data = self.get_conditioning_data(context, unet)
controlnet_data = self.prep_control_data( controlnet_data = self.prep_control_data(
context=context, context=context,
@ -776,12 +776,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack=exit_stack, exit_stack=exit_stack,
) )
num_inference_steps, timesteps, init_timestep = self.init_scheduler( num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
steps=self.steps, steps=self.steps,
denoising_start=self.denoising_start, denoising_start=self.denoising_start,
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
seed=seed,
) )
result_latents = pipeline.latents_from_embeddings( result_latents = pipeline.latents_from_embeddings(
@ -794,6 +795,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask, gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,

View File

@ -295,6 +295,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
latents: torch.Tensor, latents: torch.Tensor,
num_inference_steps: int, num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
*, *,
noise: Optional[torch.Tensor], noise: Optional[torch.Tensor],
@ -355,6 +356,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents, latents,
timesteps, timesteps,
conditioning_data, conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
@ -381,6 +383,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents: torch.Tensor, latents: torch.Tensor,
timesteps, timesteps,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
scheduler_step_kwargs: dict[str, Any],
*, *,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
@ -435,6 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data, conditioning_data,
step_index=i, step_index=i,
total_step_count=len(timesteps), total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
@ -466,6 +470,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
@ -569,7 +574,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again. # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance: for guidance in additional_guidance:

View File

@ -1,7 +1,5 @@
import dataclasses from dataclasses import dataclass
import inspect from typing import List, Optional, Union
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import torch import torch
@ -71,23 +69,5 @@ class ConditioningData:
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
""" """
guidance_rescale_multiplier: float = 0 guidance_rescale_multiplier: float = 0
scheduler_args: dict[str, Any] = field(default_factory=dict)
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)