feat(app): signal progress while processing spandrel tiles

This commit is contained in:
psychedelicious 2024-08-03 22:02:11 +10:00
parent 487815b181
commit 682280683a

View File

@ -61,6 +61,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
step_callback: Callable[[int, int], None],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
@ -103,7 +104,12 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
# Update progress, starting with 0.
step_callback(0, pbar.total)
for tile, scaled_tile in pbar:
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
@ -136,12 +142,27 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
:,
] = output_tile[top_overlap:, left_overlap:, :]
step_callback(pbar.n + 1, pbar.total)
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
return pil_image
def _get_step_callback(self, context: InvocationContext) -> Callable[[int, int], None]:
invocation_type = self.get_type()
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
name=invocation_type,
step=step,
total_steps=total_steps,
message="Processing image",
)
return step_callback
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
@ -156,7 +177,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context)
)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@ -202,7 +225,9 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context)
)
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
@ -213,7 +238,13 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
pil_image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
self._get_step_callback(context),
)
iterations += 1
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.