Use image-space tile dimensions on the TiledMultiDiffusionDenoiseLatents invocation. This is more natural for many users.

This commit is contained in:
Ryan Dick 2024-06-25 10:10:50 -04:00 committed by Kent Keirsey
parent 06f49a30f6
commit b9946e50f9

View File

@ -85,13 +85,19 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(default=64, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=64, gt=0, description="Width of the tiles in latent space.")
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=16,
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The overlap between adjacent tiles in latent space. Tiles will be cropped during merging "
"(if necessary) to ensure that they overlap by exactly this amount.",
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
@ -159,6 +165,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
@ -166,9 +177,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_overlap,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
@ -207,8 +218,8 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=self.tile_height,
latent_width=self.tile_width,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
@ -253,7 +264,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=self.tile_overlap,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,