fix(nodes): fix step_callback + fast latents generation

this depends on the small change in #2957
This commit is contained in:
psychedelicious 2023-03-15 23:50:26 +11:00
parent 5347c12fed
commit 67f8f222d9
2 changed files with 32 additions and 9 deletions

View File

@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
float(step) / float(self.steps),
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache

View File

@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@ -23,8 +26,9 @@ class EventServiceBase:
self,
graph_execution_state_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
percent: float,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
@ -32,8 +36,9 @@ class EventServiceBase:
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
percent=percent,
total_steps=total_steps,
),
)