Mirror of https://github.com/invoke-ai/InvokeAI
Add sdxl generation preview (#3862)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Description

Add progress preview for SDXL generation nodes.
Commit: ddf7ddc2c1
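The preview technique used throughout this diff is a cheap linear projection of the 4-channel latent tensor to RGB, rather than a full VAE decode per step. A minimal sketch of the idea (the function name here is illustrative, not the PR's API):

```python
import torch
from PIL import Image

def latents_to_preview(latents: torch.Tensor, rgb_factors: torch.Tensor) -> Image.Image:
    # latents: (batch, 4, height/8, width/8); rgb_factors: (4, 3)
    # Project each latent pixel's four channels onto an approximate RGB basis.
    rgb = latents[0].permute(1, 2, 0) @ rgb_factors          # (h/8, w/8, 3)
    rgb = ((rgb + 1) / 2).clamp(0, 1).mul(255).byte().cpu()  # [-1, 1] -> [0, 255]
    return Image.fromarray(rgb.numpy())
```

The resulting preview is 1/8 of the final resolution, which is why the callbacks below multiply width and height by 8 before reporting them to the UI.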
@@ -764,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation):
             dtype=vae.dtype
         )  # FIXME: uses torch.randn. make reproducible!
 
-        latents = 0.18215 * latents
+        latents = vae.config.scaling_factor * latents
         latents = latents.to(dtype=orig_dtype)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
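Background on the one-line change above: 0.18215 is the latent scaling factor of the SD 1.x VAE, while SDXL's VAE uses 0.13025, so reading `vae.config.scaling_factor` keeps image-to-latents encoding correct for both model families. A hedged sketch with diffusers' `AutoencoderKL` (the model id and input tensor are illustrative):

```python
import torch
from diffusers import AutoencoderKL

# The scaling factor lives in the VAE config, so it follows whichever model is loaded.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")  # illustrative model id
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed input image
latents = vae.encode(image).latent_dist.sample()
latents = vae.config.scaling_factor * latents  # 0.13025 for SDXL, 0.18215 for SD 1.x
```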
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
 from pydantic import Field, validator
 
 from ...backend.model_management import ModelType, SubModelType
+from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 
@@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
             },
         }
 
+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
         latents = context.services.latents.get(self.noise.latents_name)
 
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
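The `dispatch_progress` method added above is invoked once per denoising step (see the call sites in the next hunks) and forwards the in-progress latents to `stable_diffusion_xl_step_callback`, which renders and emits the preview. For comparison, the same per-step hook pattern in plain diffusers (as its API stood at the time of this PR) looked roughly like this; the pipeline and prompt are illustrative:

```python
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative model id
)

def on_step(step: int, timestep: int, latents) -> None:
    # Plays the same role as dispatch_progress: hand the in-progress
    # latents to whatever renders and reports the preview.
    print(f"step {step}: latents shape {tuple(latents.shape)}")

image = pipe("an astronaut riding a horse", callback=on_step, callback_steps=1).images[0]
```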
@@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        #if callback is not None and i % callback_steps == 0:
                        #    callback(i, t, latents)
            else:
@@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        #if callback is not None and i % callback_steps == 0:
                        #    callback(i, t, latents)
 
@@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             },
         }
 
+    def dispatch_progress(
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        sample,
+        step,
+        total_steps,
+    ) -> None:
+        stable_diffusion_xl_step_callback(
+            context=context,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            sample=sample,
+            step=step,
+            total_steps=total_steps,
+        )
+
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
         latents = context.services.latents.get(self.latents.latents_name)
 
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        #if callback is not None and i % callback_steps == 0:
                        #    callback(i, t, latents)
            else:
@@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
+                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        #if callback is not None and i % callback_steps == 0:
                        #    callback(i, t, latents)
 
@@ -1,9 +1,30 @@
+import torch
+from PIL import Image
 from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.models.image import ProgressImage
 from ..invocations.baseinvocation import InvocationContext
 from ...backend.util.util import image_to_dataURL
 from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
+from invokeai.app.services.config import InvokeAIAppConfig
+
+
+def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
+    latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
+
+    if smooth_matrix is not None:
+        latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
+        latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
+        latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
+
+    latents_ubyte = (
+        ((latent_image + 1) / 2)
+        .clamp(0, 1)  # change scale from -1..1 to 0..1
+        .mul(0xFF)  # to 0..255
+        .byte()
+    ).cpu()
+
+    return Image.fromarray(latents_ubyte.numpy())
+
+
 def stable_diffusion_step_callback(
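A quick usage sketch of the new `sample_to_lowres_estimated_image` helper above; the random tensor stands in for a real in-progress denoising sample:

```python
import torch

# Shapes: samples is (batch, 4, h/8, w/8); latent_rgb_factors is (4, 3).
samples = torch.randn(1, 4, 64, 64)  # stand-in for a 512x512 generation
v1_5_factors = torch.tensor(
    [
        [0.3444, 0.1385, 0.0670],
        [0.1247, 0.4027, 0.1494],
        [-0.3192, 0.2513, 0.2103],
        [-0.1307, -0.1874, -0.7445],
    ],
    dtype=samples.dtype,
)

image = sample_to_lowres_estimated_image(samples, v1_5_factors)
print(image.size)  # (64, 64): 1/8 of the final resolution
```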
@@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
     # step = intermediate_state.step
 
     # TODO: only output a preview image when requested
-    image = Generator.sample_to_lowres_estimated_image(sample)
+
+    # originally adapted from code by @erucipe and @keturn here:
+    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+
+    # these updated numbers for v1.5 are from @torridgristle
+    v1_5_latent_rgb_factors = torch.tensor(
+        [
+            #    R        G        B
+            [0.3444, 0.1385, 0.0670],  # L1
+            [0.1247, 0.4027, 0.1494],  # L2
+            [-0.3192, 0.2513, 0.2103],  # L3
+            [-0.1307, -0.1874, -0.7445],  # L4
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
 
     (width, height) = image.size
     width *= 8
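Note that the factors tensor above is built with `dtype=sample.dtype, device=sample.device`: the `@` projection requires both operands to match, which matters when sampling runs in fp16 and/or on CUDA. A small illustration of why (the tensors here are hypothetical):

```python
import torch

sample = torch.randn(1, 4, 64, 64, dtype=torch.float16)
factors = torch.randn(4, 3)  # float32 by default

try:
    sample[0].permute(1, 2, 0) @ factors  # mixed fp16/fp32 matmul
except RuntimeError as err:
    print(err)  # dtype mismatch

factors = factors.to(dtype=sample.dtype, device=sample.device)
rgb = sample[0].permute(1, 2, 0) @ factors  # fine: (64, 64, 3)
```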
@@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
         step=intermediate_state.step,
         total_steps=node["steps"],
     )
+
+
+def stable_diffusion_xl_step_callback(
+    context: InvocationContext,
+    node: dict,
+    source_node_id: str,
+    sample,
+    step,
+    total_steps,
+):
+    if context.services.queue.is_canceled(context.graph_execution_state_id):
+        raise CanceledException
+
+    sdxl_latent_rgb_factors = torch.tensor(
+        [
+            #    R        G        B
+            [ 0.3816,  0.4930,  0.5320],
+            [-0.3753,  0.1631,  0.1739],
+            [ 0.1770,  0.3588, -0.2048],
+            [-0.4350, -0.2644, -0.4289],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    sdxl_smooth_matrix = torch.tensor(
+        [
+            #[ 0.0478, 0.1285, 0.0478],
+            #[ 0.1285, 0.2948, 0.1285],
+            #[ 0.0478, 0.1285, 0.0478],
+            [0.0358, 0.0964, 0.0358],
+            [0.0964, 0.4711, 0.0964],
+            [0.0358, 0.0964, 0.0358],
+        ],
+        dtype=sample.dtype,
+        device=sample.device,
+    )
+
+    image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    context.services.events.emit_generator_progress(
+        graph_execution_state_id=context.graph_execution_state_id,
+        node=node,
+        source_node_id=source_node_id,
+        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
+        step=step,
+        total_steps=total_steps,
+    )
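One note on the SDXL-only step above: `sdxl_smooth_matrix` is a normalized 3x3 blur kernel (its nine entries sum to roughly 1.0: 4*0.0358 + 4*0.0964 + 0.4711 ≈ 1.0), applied with `conv2d` to each RGB channel independently to soften blockiness in the SDXL latent preview. A self-contained check of that reshuffle (variable names are illustrative):

```python
import torch

smooth = torch.tensor(
    [
        [0.0358, 0.0964, 0.0358],
        [0.0964, 0.4711, 0.0964],
        [0.0358, 0.0964, 0.0358],
    ]
)
print(smooth.sum())  # ~1.0, so smoothing preserves overall brightness

rgb = torch.randn(64, 64, 3)                     # (h, w, 3) estimated RGB image
x = rgb.unsqueeze(0).permute(3, 0, 1, 2)         # (3, 1, h, w): channels become batch
x = torch.nn.functional.conv2d(x, smooth.reshape(1, 1, 3, 3), padding=1)
rgb_smoothed = x.permute(1, 2, 3, 0).squeeze(0)  # back to (h, w, 3)
assert rgb_smoothed.shape == (64, 64, 3)
```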