From 29a590ccedb262404b1c4eb7d73251184c7eac78 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Thu, 20 Jul 2023 18:45:54 +0300
Subject: [PATCH] Add sdxl generation preview
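
Render low-resolution previews of in-progress SDXL generations, matching the
preview behavior that already exists for SD 1.x. Both SDXL latents invocations
now dispatch a step callback during denoising; the callback projects the
current latents to an approximate RGB image, smooths it, and emits it to the
frontend as a progress event.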

---
 invokeai/app/invocations/sdxl.py   | 47 +++++++++++++++
 invokeai/app/util/step_callback.py | 93 +++++++++++++++++++++++++++++-
 2 files changed, 139 insertions(+), 1 deletion(-)

diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index c091c8e49b..f877b22924 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
 from pydantic import Field, validator
 
 from ...backend.model_management import ModelType, SubModelType
+from ..util.step_callback import stable_diffusion_xl_step_callback
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 
@@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             },
         }
 
+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
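+        """Report denoising progress for this node via the SDXL preview callback."""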
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
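+        # resolve this prepared node's source graph node so progress events target the right node in the UI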
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
         latents = context.services.latents.get(self.noise.latents_name)
 
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                         # call the callback, if provided
                         if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                             progress_bar.update()
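+                            # push a low-res latent preview to the client for this step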
+                            self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                             #if callback is not None and i % callback_steps == 0:
                             #    callback(i, t, latents)
             else:
@@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                         # call the callback, if provided
                         if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                             progress_bar.update()
+                            self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                             #if callback is not None and i % callback_steps == 0:
                             #    callback(i, t, latents)
 
@@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             },
         }
 
+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
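+        """Report denoising progress for this node via the SDXL preview callback."""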
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
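+        # resolve this prepared node's source graph node so progress events target the right node in the UI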
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
         latents = context.services.latents.get(self.latents.latents_name)
 
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                         # call the callback, if provided
                         if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                             progress_bar.update()
+                            self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                             #if callback is not None and i % callback_steps == 0:
                             #    callback(i, t, latents)
             else:
@@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                         # call the callback, if provided
                         if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                             progress_bar.update()
+                            self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                             #if callback is not None and i % callback_steps == 0:
                             #    callback(i, t, latents)
 
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index b4b9a25909..1e8939b0bf 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -1,9 +1,30 @@
+import torch
+from PIL import Image
 from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.models.image import ProgressImage
 from ..invocations.baseinvocation import InvocationContext
 from ...backend.util.util import image_to_dataURL
 from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
+
+
+def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
+    # project the four latent channels onto RGB with a fixed linear map
+    latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
+
+    if smooth_matrix is not None:
+        # convolve each channel with a small kernel to soften latent-space blockiness
+        latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
+        latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
+        latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
+
+    latents_ubyte = (
+        ((latent_image + 1) / 2)  # change scale from -1..1 to 0..1
+        .clamp(0, 1)
+        .mul(0xFF)  # to 0..255
+        .byte()
+    ).cpu()
+
+    return Image.fromarray(latents_ubyte.numpy())
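+
+# hypothetical usage sketch: a (1, 4, H, W) latent batch and a (4, 3) factor
+# matrix yield an H x W RGB PIL image, e.g.:
+#   preview = sample_to_lowres_estimated_image(latents, v1_5_latent_rgb_factors)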
 
 
 def stable_diffusion_step_callback(
@@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
     #     step = intermediate_state.step
 
     # TODO: only output a preview image when requested
-    image = Generator.sample_to_lowres_estimated_image(sample)
+
+    # originally adapted from code by @erucipe and @keturn here:
+    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+
+    # these updated numbers for v1.5 are from @torridgristle
+    v1_5_latent_rgb_factors = torch.tensor(
+        [
+            #    R        G        B
+            [0.3444, 0.1385, 0.0670],  # L1
+            [0.1247, 0.4027, 0.1494],  # L2
+            [-0.3192, 0.2513, 0.2103],  # L3
+            [-0.1307, -0.1874, -0.7445],  # L4
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
 
     (width, height) = image.size
     width *= 8
@@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
         step=intermediate_state.step,
         total_steps=node["steps"],
     )
+
+
+def stable_diffusion_xl_step_callback(
+    context: InvocationContext,
+    node: dict,
+    source_node_id: str,
+    sample,
+    step,
+    total_steps,
+):
+    if context.services.queue.is_canceled(context.graph_execution_state_id):
+        raise CanceledException
+
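+    # approximate linear map from SDXL's four latent channels to RGB;
+    # SDXL's latent space differs from SD 1.x, so it uses its own factors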
+    sdxl_latent_rgb_factors = torch.tensor(
+        [
+            #   R        G        B
+            [ 0.3816,  0.4930,  0.5320],
+            [-0.3753,  0.1631,  0.1739],
+            [ 0.1770,  0.3588, -0.2048],
+            [-0.4350, -0.2644, -0.4289],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
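+    # normalized 3x3 smoothing kernel; the commented-out rows are an
+    # alternative with a lower center weight, i.e. a stronger blur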
+    sdxl_smooth_matrix = torch.tensor(
+        [
+            #[ 0.0478,  0.1285,  0.0478],
+            #[ 0.1285,  0.2948,  0.1285],
+            #[ 0.0478,  0.1285,  0.0478],
+            [0.0358, 0.0964, 0.0358],
+            [0.0964, 0.4711, 0.0964],
+            [0.0358, 0.0964, 0.0358],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+
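+    # latents are 1/8 the output resolution, so report the full-size dimensions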
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    context.services.events.emit_generator_progress(
+        graph_execution_state_id=context.graph_execution_state_id,
+        node=node,
+        source_node_id=source_node_id,
+        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
+        step=step,
+        total_steps=total_steps,
+    )