mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import torch
|
|
from ..invocations.baseinvocation import InvocationContext
|
|
from ...backend.util.util import image_to_dataURL
|
|
from ...backend.generator.base import Generator
|
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
|
|
|
def fast_latents_step_callback(
|
|
sample: torch.Tensor,
|
|
step: int,
|
|
steps: int,
|
|
id: str,
|
|
context: InvocationContext,
|
|
):
|
|
# TODO: only output a preview image when requested
|
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
|
|
|
(width, height) = image.size
|
|
width *= 8
|
|
height *= 8
|
|
|
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
|
|
|
context.services.events.emit_generator_progress(
|
|
context.graph_execution_state_id,
|
|
id,
|
|
{"width": width, "height": height, "dataURL": dataURL},
|
|
step,
|
|
steps,
|
|
)
|
|
|
|
|
|
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
|
"""
|
|
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
|
This adapter grabs the needed data and passes it along to the callback function.
|
|
"""
|
|
if isinstance(cb_args[0], PipelineIntermediateState):
|
|
progress_state: PipelineIntermediateState = cb_args[0]
|
|
return fast_latents_step_callback(
|
|
progress_state.latents, progress_state.step, **kwargs
|
|
)
|
|
else:
|
|
return fast_latents_step_callback(*cb_args, **kwargs)
|