InvokeAI/invokeai/app/util/step_callback.py
2023-04-04 13:32:34 +10:00

44 lines
1.4 KiB
Python

import torch
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
def fast_latents_step_callback(
    sample: torch.Tensor,
    step: int,
    steps: int,
    id: str,
    context: InvocationContext,
):
    """
    Emit a generator-progress event carrying a preview image for the current step.

    The preview is whatever ``Generator.sample_to_lowres_estimated_image``
    produces for ``sample``; it is encoded as a JPEG data URL and sent through
    ``context.services.events`` together with ``step`` and ``steps``.

    :param sample: tensor handed to ``Generator.sample_to_lowres_estimated_image``.
    :param step: current step, forwarded unchanged to the event.
    :param steps: total step count, forwarded unchanged to the event.
    :param id: identifier forwarded to the event right after the graph
        execution state id (presumably the invocation/node id -- confirm
        against the event service).
    :param context: invocation context supplying the event service and the
        graph execution state id.
    """
    # TODO: only output a preview image when requested
    preview = Generator.sample_to_lowres_estimated_image(sample)
    preview_width, preview_height = preview.size

    progress_image = {
        # Reported dimensions are 8x the preview's own size -- presumably the
        # latent-to-pixel scale of the final image; TODO confirm.
        "width": preview_width * 8,
        "height": preview_height * 8,
        "dataURL": image_to_dataURL(preview, image_format="JPEG"),
    }

    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        progress_image,
        step,
        steps,
    )
def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    Normalise the two step-callback calling conventions onto
    ``fast_latents_step_callback``.

    txt2img gives us a Tensor in the step_callback, while img2img gives us a
    PipelineIntermediateState. In the latter case only its ``latents`` and
    ``step`` are forwarded as the first two positional arguments (any further
    positional arguments are not passed on); otherwise the positional
    arguments are forwarded untouched. Keyword arguments are always passed
    through as-is.

    At least one positional argument is required: the first one is inspected
    to decide which convention is in use.
    """
    first_arg = cb_args[0]

    # Plain (non-state) call: hand everything straight to the callback.
    if not isinstance(first_arg, PipelineIntermediateState):
        return fast_latents_step_callback(*cb_args, **kwargs)

    # State object: unpack the pieces the callback expects positionally.
    return fast_latents_step_callback(first_arg.latents, first_arg.step, **kwargs)