diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index ae4f48ef77..74b29fbc8c 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -61,6 +61,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): tile_size: int, spandrel_model: SpandrelImageToImageModel, is_canceled: Callable[[], bool], + step_callback: Callable[[int, int], None], ) -> Image.Image: # Compute the image tiles. if tile_size > 0: @@ -103,7 +104,12 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) # Run the model on each tile. - for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): + pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles") + + # Update progress, starting with 0. + step_callback(0, pbar.total) + + for tile, scaled_tile in pbar: # Exit early if the invocation has been canceled. if is_canceled(): raise CanceledException @@ -136,12 +142,27 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): :, ] = output_tile[top_overlap:, left_overlap:, :] + step_callback(pbar.n + 1, pbar.total) + # Convert the output tensor to a PIL image. np_image = output_tensor.detach().numpy().astype(np.uint8) pil_image = Image.fromarray(np_image) return pil_image + def _get_step_callback(self, context: InvocationContext) -> Callable[[int, int], None]: + invocation_type = self.get_type() + + def step_callback(step: int, total_steps: int) -> None: + context.util.signal_progress( + name=invocation_type, + step=step, + total_steps=total_steps, + message="Processing image", + ) + + return step_callback + @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to @@ -156,7 +177,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): assert isinstance(spandrel_model, SpandrelImageToImageModel) # Upscale the image - pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled) + pil_image = self.upscale_image( + image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) + ) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto) @@ -202,7 +225,9 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation): assert isinstance(spandrel_model, SpandrelImageToImageModel) # First pass of upscaling. Note: `pil_image` will be mutated. - pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled) + pil_image = self.upscale_image( + image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) + ) # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model # upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions @@ -213,7 +238,13 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation): # This is an upscale model, so we should keep upscaling until we reach the target size. iterations = 1 while pil_image.width < target_width or pil_image.height < target_height: - pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled) + pil_image = self.upscale_image( + pil_image, + self.tile_size, + spandrel_model, + context.util.is_canceled, + self._get_step_callback(context), + ) iterations += 1 # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.