mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import torch
|
|
from PIL import Image
|
|
from ..invocations.baseinvocation import InvocationContext
|
|
from ...backend.util.util import image_to_dataURL
|
|
from ...backend.generator.base import Generator
|
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
|
|
|
class CanceledException(Exception):
    """Exception type signalling that an operation was canceled.

    Adds no behavior beyond ``Exception``; it exists so that cancellation
    can be raised and caught under a distinct type.
    """
|
|
|
|
def fast_latents_step_callback(
    sample: torch.Tensor,
    step: int,
    steps: int,
    id: str,
    context: InvocationContext,
):
    """Emit a generator-progress event carrying a preview image of ``sample``.

    The preview is produced by ``Generator.sample_to_lowres_estimated_image``
    and encoded as a JPEG data URL. The reported width and height are the
    preview's dimensions multiplied by 8.

    Args:
        sample: Tensor handed to ``Generator.sample_to_lowres_estimated_image``.
        step: Current step, forwarded unchanged to the progress event.
        steps: Total step count, forwarded unchanged to the progress event.
        id: Identifier forwarded to the progress event alongside the graph
            execution state id.
        context: Invocation context providing ``services.events`` and
            ``graph_execution_state_id``.
    """
    # TODO: only output a preview image when requested
    preview = Generator.sample_to_lowres_estimated_image(sample)

    preview_width, preview_height = preview.size
    # NOTE(review): the factor 8 is presumably the latent-to-pixel scale, so
    # the event reports full-resolution dimensions -- confirm.
    progress_image = {
        "width": preview_width * 8,
        "height": preview_height * 8,
        "dataURL": image_to_dataURL(preview, image_format="JPEG"),
    }

    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        progress_image,
        step,
        steps,
    )
|
|
|
|
def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    Adapt either step-callback calling convention to fast_latents_step_callback.

    txt2img gives us a Tensor in the step_callback, while img2img gives us a
    PipelineIntermediateState. This adapter grabs the needed data and passes
    it along to the callback function.

    Args:
        *cb_args: Positional arguments supplied by the caller. If the first
            one is a ``PipelineIntermediateState``, its ``latents`` and
            ``step`` are forwarded in its place; otherwise all positional
            arguments are forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged to
            ``fast_latents_step_callback``.

    Returns:
        Whatever ``fast_latents_step_callback`` returns.
    """
    # Check that a positional argument exists before indexing, so a
    # keyword-only call is forwarded instead of raising IndexError here.
    if cb_args and isinstance(cb_args[0], PipelineIntermediateState):
        progress_state: PipelineIntermediateState = cb_args[0]
        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)

    return fast_latents_step_callback(*cb_args, **kwargs)
|