Remove unused num_inference_steps.

This commit is contained in:
Ryan Dick 2024-06-12 13:39:34 -04:00 committed by Kent Keirsey
parent 230e205541
commit ffc28176fe
4 changed files with 17 additions and 46 deletions

View File

@ -601,7 +601,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_start: float, denoising_start: float,
denoising_end: float, denoising_end: float,
seed: int, seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]: ) -> Tuple[List[int], int, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin) assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu") scheduler.set_timesteps(steps, device="cpu")
@ -627,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
init_timestep = timesteps[t_start_idx : t_start_idx + 1] init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Any] = {} scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step) scheduler_step_signature = inspect.signature(scheduler.step)
@ -649,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if isinstance(scheduler, TCDScheduler): if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0}) scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs return timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask( def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor self, context: InvocationContext, latents: torch.Tensor
@ -803,7 +802,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype, dtype=unet.dtype,
) )
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
steps=self.steps, steps=self.steps,
@ -821,7 +820,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask=mask, mask=mask,
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask, gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,

View File

@ -228,7 +228,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
] ]
controlnet_data_tiles.append(tile_controlnet_data) controlnet_data_tiles.append(tile_controlnet_data)
# TODO(ryand): Logic from here down needs updating --------------------
# Denoise (i.e. "refine") each tile independently. # Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True): for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape assert latent_tile.shape == noise_tile.shape
@ -238,27 +237,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# the tiles. Ideally, the ControlNet code should be able to work with Tensors. # the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np) image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile. timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
DenoiseLatentsInvocation.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
steps=self.steps, steps=self.steps,
@ -266,7 +245,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
seed=seed, seed=seed,
) )
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM. # TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype) latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
@ -280,7 +258,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
mask=None, mask=None,
masked_latents=None, masked_latents=None,
gradient_mask=None, gradient_mask=None,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=[controlnet_data], control_data=[controlnet_data],

View File

@ -320,8 +320,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
resize_mode="just_resize_simple", resize_mode="just_resize_simple",
) )
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = ( timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
DenoiseLatentsInvocation.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
steps=self.steps, steps=self.steps,
@ -329,7 +328,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
seed=seed, seed=seed,
) )
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM. # TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype) latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
@ -343,7 +341,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
mask=None, mask=None,
masked_latents=None, masked_latents=None,
gradient_mask=None, gradient_mask=None,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=[controlnet_data], control_data=[controlnet_data],

View File

@ -283,7 +283,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def latents_from_embeddings( def latents_from_embeddings(
self, self,
latents: torch.Tensor, latents: torch.Tensor,
num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any], scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData, conditioning_data: TextConditioningData,
*, *,