Add sdxl generation preview (#3862)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:


## Description
Add a progress preview image for the SDXL generation nodes (`SDXLTextToLatentsInvocation` and `SDXLLatentsToLatentsInvocation`).
Committed by Lincoln Stein on 2023-07-20 12:21:57 -04:00 via GitHub, commit `ddf7ddc2c1`.
3 changed files with 140 additions and 2 deletions

#### `invokeai/app/invocations/latent.py`

```diff
@@ -764,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation):
             dtype=vae.dtype
         )  # FIXME: uses torch.randn. make reproducible!
-        latents = 0.18215 * latents
+        latents = vae.config.scaling_factor * latents
         latents = latents.to(dtype=orig_dtype)

         name = f"{context.graph_execution_state_id}__{self.id}"
```

#### `invokeai/app/invocations/sdxl.py`

```diff
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
 from pydantic import Field, validator

 from ...backend.model_management import ModelType, SubModelType
+from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
```
```diff
@@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             },
         }

+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
         latents = context.services.latents.get(self.noise.latents_name)

         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
```
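Note that `dispatch_progress` serializes the invocation itself via pydantic's `.dict()`, so each progress event carries the node's own parameters to the frontend. A toy illustration of that serialization, with a hypothetical stand-in model rather than a real invocation:

```python
# Toy illustration (not from this PR): invocations are pydantic models, so
# self.dict() yields a plain dict of the node's fields for the event payload.
from pydantic import BaseModel

class FakeNode(BaseModel):  # hypothetical stand-in for a real invocation
    id: str = "7"
    steps: int = 30
    cfg_scale: float = 7.5

print(FakeNode().dict())  # {'id': '7', 'steps': 30, 'cfg_scale': 7.5}
```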
```diff
@@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                     # call the callback, if provided
                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                         progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                         #if callback is not None and i % callback_steps == 0:
                         #    callback(i, t, latents)
         else:
```
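The guard around `dispatch_progress` reuses diffusers' progress-bar cadence: always fire on the last step, and otherwise fire once per `scheduler.order` iterations after warmup. A standalone sketch of which loop indices fire, with illustrative values not taken from this PR:

```python
# Standalone sketch (not from this PR): which loop indices trigger the
# callback, for illustrative scheduler order and warmup values.
scheduler_order = 2        # e.g. a second-order scheduler such as Heun
num_warmup_steps = 1       # illustrative
timesteps = list(range(20))

fired = [
    i for i in range(len(timesteps))
    if i == len(timesteps) - 1
    or ((i + 1) > num_warmup_steps and (i + 1) % scheduler_order == 0)
]
print(fired)  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
```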
```diff
@@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                     # call the callback, if provided
                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                         progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                         #if callback is not None and i % callback_steps == 0:
                         #    callback(i, t, latents)
```
```diff
@@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             },
         }

+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
         latents = context.services.latents.get(self.latents.latents_name)

         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
```
```diff
@@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                     # call the callback, if provided
                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                         progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                         #if callback is not None and i % callback_steps == 0:
                         #    callback(i, t, latents)
         else:
```
```diff
@@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                     # call the callback, if provided
                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                         progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                         #if callback is not None and i % callback_steps == 0:
                         #    callback(i, t, latents)
```

#### `invokeai/app/util/step_callback.py`

```diff
@@ -1,9 +1,30 @@
+import torch
+from PIL import Image
+
 from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.models.image import ProgressImage
 from ..invocations.baseinvocation import InvocationContext
 from ...backend.util.util import image_to_dataURL
 from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
+from invokeai.app.services.config import InvokeAIAppConfig
+
+
+def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
+    latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
+
+    if smooth_matrix is not None:
+        latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
+        latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
+        latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
+
+    latents_ubyte = (
+        ((latent_image + 1) / 2)
+        .clamp(0, 1)  # change scale from -1..1 to 0..1
+        .mul(0xFF)  # to 0..255
+        .byte()
+    ).cpu()
+
+    return Image.fromarray(latents_ubyte.numpy())
+
+
 def stable_diffusion_step_callback(
```
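To see what the new helper produces, here is a hedged usage sketch with random data; the factors are made up rather than the calibrated ones used below, and it assumes the function above is importable:

```python
# Usage sketch (not from this PR): decode a random 4-channel latent into a
# rough RGB preview without touching the VAE. The factors here are random.
import torch

latents = torch.randn(1, 4, 16, 16)   # (batch, latent channels, h/8, w/8)
rgb_factors = torch.randn(4, 3)       # maps 4 latent channels -> RGB
preview = sample_to_lowres_estimated_image(latents, rgb_factors)
print(preview.size)                   # (16, 16); the callback reports 8x this
```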
```diff
@@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
     # step = intermediate_state.step

     # TODO: only output a preview image when requested
-    image = Generator.sample_to_lowres_estimated_image(sample)
+
+    # originally adapted from code by @erucipe and @keturn here:
+    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+
+    # these updated numbers for v1.5 are from @torridgristle
+    v1_5_latent_rgb_factors = torch.tensor(
+        [
+            #     R        G        B
+            [ 0.3444,  0.1385,  0.0670],  # L1
+            [ 0.1247,  0.4027,  0.1494],  # L2
+            [-0.3192,  0.2513,  0.2103],  # L3
+            [-0.1307, -0.1874, -0.7445],  # L4
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

     (width, height) = image.size
     width *= 8
```
```diff
@@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
         step=intermediate_state.step,
         total_steps=node["steps"],
     )
+
+
+def stable_diffusion_xl_step_callback(
+    context: InvocationContext,
+    node: dict,
+    source_node_id: str,
+    sample,
+    step,
+    total_steps,
+):
+    if context.services.queue.is_canceled(context.graph_execution_state_id):
+        raise CanceledException
+
+    sdxl_latent_rgb_factors = torch.tensor(
+        [
+            #     R        G        B
+            [ 0.3816,  0.4930,  0.5320],
+            [-0.3753,  0.1631,  0.1739],
+            [ 0.1770,  0.3588, -0.2048],
+            [-0.4350, -0.2644, -0.4289],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    sdxl_smooth_matrix = torch.tensor(
+        [
+            # [0.0478, 0.1285, 0.0478],
+            # [0.1285, 0.2948, 0.1285],
+            # [0.0478, 0.1285, 0.0478],
+            [0.0358, 0.0964, 0.0358],
+            [0.0964, 0.4711, 0.0964],
+            [0.0358, 0.0964, 0.0358],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    context.services.events.emit_generator_progress(
+        graph_execution_state_id=context.graph_execution_state_id,
+        node=node,
+        source_node_id=source_node_id,
+        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
+        step=step,
+        total_steps=total_steps,
+    )
```
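One design note: unlike the SD 1.x path, the SDXL preview additionally smooths the estimated RGB image with a 3×3 kernel whose weights sum to roughly 1, so the blur preserves overall brightness; the commented-out rows appear to be an alternative, softer kernel. A quick arithmetic check, not from this PR:

```python
# Quick check (not from this PR): the SDXL smooth matrix is effectively a
# normalized 3x3 blur kernel -- its entries sum to ~1.0.
import torch

k = torch.tensor([
    [0.0358, 0.0964, 0.0358],
    [0.0964, 0.4711, 0.0964],
    [0.0358, 0.0964, 0.0358],
])
print(k.sum().item())  # ~0.9999 -> brightness-preserving smoothing
```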