From 29a590ccedb262404b1c4eb7d73251184c7eac78 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 20 Jul 2023 18:45:54 +0300 Subject: [PATCH 1/2] Add sdxl generation preview --- invokeai/app/invocations/sdxl.py | 47 +++++++++++++++ invokeai/app/util/step_callback.py | 93 +++++++++++++++++++++++++++++- 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index c091c8e49b..f877b22924 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union from pydantic import Field, validator from ...backend.model_management import ModelType, SubModelType +from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) @@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation): }, } + def dispatch_progress( + self, + context: InvocationContext, + source_node_id: str, + sample, + step, + total_steps, + ) -> None: + stable_diffusion_xl_step_callback( + context=context, + node=self.dict(), + source_node_id=source_node_id, + sample=sample, + step=step, + total_steps=total_steps, + ) + # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id + ) + source_node_id = graph_execution_state.prepared_source_mapping[self.id] latents = context.services.latents.get(self.noise.latents_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) @@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() + self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) #if callback is not None and i % callback_steps == 0: # callback(i, t, latents) else: @@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() + self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) #if callback is not None and i % callback_steps == 0: # callback(i, t, latents) @@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): }, } + def dispatch_progress( + self, + context: InvocationContext, + source_node_id: str, + sample, + step, + total_steps, + ) -> None: + stable_diffusion_xl_step_callback( + context=context, + node=self.dict(), + source_node_id=source_node_id, + sample=sample, + step=step, + total_steps=total_steps, + ) + # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id + ) + source_node_id = graph_execution_state.prepared_source_mapping[self.id] latents = context.services.latents.get(self.latents.latents_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) @@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() + self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) #if callback is not None and i % callback_steps == 0: # callback(i, t, latents) else: @@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() + self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) #if callback is not None and i % callback_steps == 0: # callback(i, t, latents) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index b4b9a25909..1e8939b0bf 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,9 +1,30 @@ +import torch +from PIL import Image from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.image import ProgressImage from ..invocations.baseinvocation import InvocationContext from ...backend.util.util import image_to_dataURL from ...backend.generator.base import Generator from ...backend.stable_diffusion import PipelineIntermediateState +from invokeai.app.services.config import InvokeAIAppConfig + + +def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None): + latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors + + if smooth_matrix is not None: + latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2) + latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1) + latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0) + + latents_ubyte = ( + ((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte() + ).cpu() + + return Image.fromarray(latents_ubyte.numpy()) def stable_diffusion_step_callback( @@ -37,7 +58,24 @@ def stable_diffusion_step_callback( # step = intermediate_state.step # TODO: only output a preview image when requested - image = Generator.sample_to_lowres_estimated_image(sample) + + # origingally adapted from code by @erucipe and @keturn here: + # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 + + # these updated numbers for v1.5 are from @torridgristle + v1_5_latent_rgb_factors = torch.tensor( + [ + # R G B + [0.3444, 0.1385, 0.0670], # L1 + [0.1247, 0.4027, 0.1494], # L2 + [-0.3192, 0.2513, 0.2103], # L3 + [-0.1307, -0.1874, -0.7445], # L4 + ], + dtype=sample.dtype, + device=sample.device, + ) + + image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) (width, height) = image.size width *= 8 @@ -53,3 +91,56 @@ def stable_diffusion_step_callback( step=intermediate_state.step, total_steps=node["steps"], ) + +def stable_diffusion_xl_step_callback( + context: InvocationContext, + node: dict, + source_node_id: str, + sample, + step, + total_steps, +): + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException + + sdxl_latent_rgb_factors = torch.tensor( + [ + # R G B + [ 0.3816, 0.4930, 0.5320], + [-0.3753, 0.1631, 0.1739], + [ 0.1770, 0.3588, -0.2048], + [-0.4350, -0.2644, -0.4289], + ], + dtype=sample.dtype, + device=sample.device, + ) + + sdxl_smooth_matrix = torch.tensor( + [ + #[ 0.0478, 0.1285, 0.0478], + #[ 0.1285, 0.2948, 0.1285], + #[ 0.0478, 0.1285, 0.0478], + [0.0358, 0.0964, 0.0358], + [0.0964, 0.4711, 0.0964], + [0.0358, 0.0964, 0.0358], + ], + dtype=sample.dtype, + device=sample.device, + ) + + image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix) + + (width, height) = image.size + width *= 8 + height *= 8 + + dataURL = image_to_dataURL(image, image_format="JPEG") + + context.services.events.emit_generator_progress( + graph_execution_state_id=context.graph_execution_state_id, + node=node, + source_node_id=source_node_id, + progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), + step=step, + total_steps=total_steps, + ) \ No newline at end of file From 4a0774b260c73c11d50ab98b8cccbd006aa10cc6 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 20 Jul 2023 18:54:51 +0300 Subject: [PATCH 2/2] Use scale from vae --- invokeai/app/invocations/latent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b4c3454c88..6082057bd3 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -764,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation): dtype=vae.dtype ) # FIXME: uses torch.randn. make reproducible! - latents = 0.18215 * latents + latents = vae.config.scaling_factor * latents latents = latents.to(dtype=orig_dtype) name = f"{context.graph_execution_state_id}__{self.id}"